The code below reads a CSV data file (the ex1 multivariate linear regression exercise from Andrew Ng's ML course) and then tries to fit a linear model to the dataset using a learning rate of alpha = 0.01. Gradient descent updates the parameters (the theta vector) 400 times. The values of alpha and num_of_iterations were given in the problem statement. I tried a vectorized implementation to obtain the optimal parameter values, but gradient descent does not converge: the error keeps growing.
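Here is the vectorized batch update described above, written out for reference. X is the m × 3 design matrix: a column of ones followed by the two features, called `X_with_bias` in the code. y is the m × 1 target vector (`Y`), θ is the 3 × 1 parameter vector, and α is the learning rate. This is the standard least-squares formulation. The second line is what `gradient_descent` computes on each iteration.

$$
J(\theta) = \frac{1}{2m}\,(X\theta - y)^\top (X\theta - y)
$$

$$
\theta := \theta - \frac{\alpha}{m}\, X^\top (X\theta - y)
$$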
# Imports
```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
```
# Model Preparation
## Gradient descent
```python
def gradient_descent(m, theta, alpha, num_of_iterations, X, Y):
    # print(m, theta, alpha, num_of_iterations)
    for i in range(num_of_iterations):
        htheta_vector = np.dot(X, theta)  # predictions h_theta(x), shape (m, 1)
        # print(X.shape, theta.shape, htheta_vector.shape)
        error_vector = htheta_vector - Y  # residuals, shape (m, 1)
        gradient_vector = (1/m) * (np.dot(X.T, error_vector))  # each element in gradient_vector corresponds to each theta
        theta = theta - alpha * gradient_vector  # simultaneous update of all parameters
    return theta
```
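The symptom is an error that grows instead of shrinking. Recording the cost J(θ) = (1/2m)·Σ(hθ(x) − y)² after every update makes divergence visible at a glance.

Below is a minimal sketch of the same vectorized update as `gradient_descent` above, extended to keep a cost history. It assumes the same shapes: `X` is (m, n+1) and includes the bias column, `Y` is (m, 1), and `theta` is (n+1, 1). The names `compute_cost` and `gradient_descent_with_history` are hypothetical. The data in the usage lines is synthetic and made up for illustration only.

```python
import numpy as np


def compute_cost(X, Y, theta):
    """Squared-error cost J(theta) = (1 / 2m) * sum((X @ theta - Y) ** 2)."""
    m = X.shape[0]
    error = X @ theta - Y  # (m, 1) residuals
    return float(np.sum(error ** 2) / (2 * m))


def gradient_descent_with_history(theta, alpha, num_of_iterations, X, Y):
    """Vectorized batch gradient descent that also records J(theta) after each update."""
    m = X.shape[0]
    cost_history = np.empty(num_of_iterations)
    for i in range(num_of_iterations):
        error = X @ theta - Y                        # (m, 1) residuals
        theta = theta - (alpha / m) * (X.T @ error)  # simultaneous update
        cost_history[i] = compute_cost(X, Y, theta)
    return theta, cost_history


# Illustrative usage on synthetic, already well-scaled data (made up for this sketch).
rng = np.random.default_rng(0)
X_demo = np.hstack([np.ones((50, 1)), rng.normal(size=(50, 2))])
true_theta = np.array([[1.0], [2.0], [-3.0]])
Y_demo = X_demo @ true_theta
theta_demo, history = gradient_descent_with_history(np.zeros((3, 1)), 0.1, 400, X_demo, Y_demo)
print(history[0], history[-1])  # the cost should shrink; a growing cost signals divergence
```

With a learning rate that suits the data, `history` decreases monotonically. With the question's setup, it would grow by orders of magnitude each step.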
# Main
```python
def main():
    df = pd.read_csv('data2.csv', header = None)  # loading data
    data = df.values  # converting dataframe to numpy array
    X = data[:, 0:2]  # features: the first two columns
    # print(X.shape)
    Y = data[:, -1]  # target: the last column
    m = (X.shape)[0]  # number of training examples
    Y = Y.reshape(m, 1)
    ones = np.ones(shape = (m,1))
    X_with_bias = np.concatenate([ones, X], axis = 1)  # prepend the intercept column
    theta = np.zeros(shape = (3,1))  # two features, so three parameters
    alpha = 0.001
    num_of_iterations = 400
    theta = gradient_descent(m, theta, alpha, num_of_iterations, X_with_bias, Y)  # calling gradient descent
    # print('Parameters learned: ' + str(theta))

if __name__ == '__main__':
    main()
```
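One common reason batch gradient descent diverges on housing data like this is that the features are on very different scales and are not normalized. The original course exercise includes a feature-normalization step (`featureNormalize`) before running gradient descent.

Below is a minimal sketch of z-score normalization that could be applied to `X` in `main()` before the bias column is added. `feature_normalize` is a hypothetical helper name. It uses NumPy's default population standard deviation (`ddof=0`), and the demo data is synthetic.

```python
import numpy as np


def feature_normalize(X):
    """Scale each feature column to zero mean and unit standard deviation.

    X is the raw feature matrix *without* the bias column. Returns the
    normalized matrix plus the per-column mean and std, which must be reused
    to normalize any new example before prediction.
    """
    X = np.asarray(X, dtype=float)
    mu = X.mean(axis=0)
    sigma = X.std(axis=0)    # population std (NumPy default ddof=0)
    sigma[sigma == 0] = 1.0  # avoid division by zero for constant columns
    return (X - mu) / sigma, mu, sigma


# Sketch of how main() could use it (same variable names as the question's code):
# X_norm, mu, sigma = feature_normalize(X)
# X_with_bias = np.concatenate([np.ones((m, 1)), X_norm], axis=1)
# theta = gradient_descent(m, theta, alpha, num_of_iterations, X_with_bias, Y)

# Synthetic demo: one column in the thousands, one small integer column.
rng = np.random.default_rng(1)
X_raw = np.column_stack([rng.uniform(800, 4500, 47), rng.integers(1, 6, 47)])
X_norm, mu, sigma = feature_normalize(X_raw)
print(X_norm.mean(axis=0), X_norm.std(axis=0))  # approximately [0, 0] and [1, 1]
```

The bias column of ones is added after normalization, so it is not scaled.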
The output includes this warning:
```
/home/krish-thorcode/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py:8: RuntimeWarning: invalid value encountered in subtract
```
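This `RuntimeWarning` is NumPy reporting that an operation produced NaN. Line 8 of the notebook cell appears to correspond to the update `theta = theta - alpha * gradient_vector`. Once the entries of `theta` and of `alpha * gradient_vector` have both overflowed to infinities of the same sign, that subtraction computes inf - inf.

Below is a minimal reproduction, with made-up single-element arrays standing in for the real ones.

```python
import warnings

import numpy as np

# Hypothetical stand-ins: theta and the scaled gradient have both overflowed to -inf,
# so the update theta - alpha * gradient_vector evaluates -inf - (-inf), which is NaN.
theta = np.array([[-np.inf]])
scaled_gradient = np.array([[-np.inf]])

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    new_theta = theta - scaled_gradient

print(new_theta)          # [[nan]]
print(caught[0].message)  # invalid value encountered in subtract
```

The warning is therefore a consequence of the divergence, not its cause. By the time it appears, the parameters have already overflowed.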
Values of the error vector (`error_vector`) at different iterations:
```
Iteration 1 [[-399900.] [-329900.] [-369000.] [-232000.] [-539900.] [-299900.] [-314900.] [-198999.] [-212000.] [-242500.] [-239999.] [-347000.] [-329999.] [-699900.] [-259900.] [-449900.] [-299900.] [-199900.] [-499998.] [-599000.] [-252900.] [-255000.] [-242900.] [-259900.] [-573900.] [-249900.] [-464500.] [-469000.] [-475000.] [-299900.] [-349900.] [-169900.] [-314900.] [-579900.] [-285900.] [-249900.] [-229900.] [-345000.] [-549000.] [-287000.] [-368500.] [-329900.] [-314000.] [-299000.] [-179900.] [-299900.] [-239500.]]
Iteration 2 [[1.60749981e+09] [1.22240841e+09] [1.83373661e+09] [1.08189071e+09] [2.29209231e+09] [1.51666004e+09] [1.17198560e+09] [1.09033113e+09] [1.05440030e+09] [1.14148964e+09] [1.48233053e+09] [1.52807496e+09] [1.44402895e+09] [3.42143452e+09] [9.68760976e+08] [1.75723592e+09] [1.00845873e+09] [9.44366284e+08] [1.99332644e+09] [2.31572369e+09] [1.35010833e+09] [1.44257442e+09] [1.22555224e+09] [1.49912323e+09] [2.97220331e+09] [8.40383843e+08] [1.11375611e+09] [1.92992696e+09] [1.68078878e+09] [2.01492327e+09] [1.40503327e+09] [7.64040689e+08] [1.55867654e+09] [2.39674784e+09] [1.38370165e+09] [1.09792232e+09] [9.46628911e+08] [1.62895368e+09] [3.22059730e+09] [1.65193796e+09] [1.27127807e+09] [1.70997383e+09] [1.96141565e+09] [9.16755655e+08] [6.50928858e+08] [1.41502023e+09] [9.19107783e+08]]
Iteration 3 [[-7.42664624e+12] [-5.64764378e+12] [-8.47145714e+12] [-4.99816153e+12] [-1.05893224e+13] [-7.00660901e+12] [-5.41467917e+12] [-5.03699402e+12] [-4.87109500e+12] [-5.27348843e+12] [-6.84776945e+12] [-7.05955046e+12] [-6.67127611e+12] [-1.58063228e+13] [-4.47576119e+12] [-8.11848565e+12] [-4.65930400e+12] [-4.36280860e+12] [-9.20918360e+12] [-1.06987452e+13] [-6.23711474e+12] [-6.66421140e+12] [-5.66176276e+12] [-6.92542434e+12] [-1.37308096e+13] [-3.88276038e+12] [-5.14641706e+12] [-8.91620784e+12] [-7.76550392e+12] [-9.30801176e+12] [-6.49125293e+12] [-3.52977344e+12] [-7.20074619e+12] [-1.10728954e+13] [-6.39242960e+12] [-5.07229174e+12] [-4.37339793e+12] [-7.52548475e+12] [-1.48779889e+13] [-7.63137769e+12] [-5.87354379e+12] [-7.89963490e+12] [-9.06093321e+12] [-4.23573710e+12] [-3.00737309e+12] [-6.53715005e+12] [-4.24632634e+12]]
Iteration 4 [[3.43099835e+16] [2.60912608e+16] [3.91368523e+16] [2.30907512e+16] [4.89210695e+16] [3.23694753e+16] [2.50149995e+16] [2.32701516e+16] [2.25037231e+16] [2.43627199e+16] [3.16356608e+16] [3.26140566e+16] [3.08202877e+16] [7.30228235e+16] [2.06773403e+16] [3.75061770e+16] [2.15252802e+16] [2.01555166e+16] [4.25450367e+16] [4.94265862e+16] [2.88145280e+16] [3.07876502e+16] [2.61564888e+16] [3.19944145e+16] [6.34342666e+16] [1.79377661e+16] [2.37756683e+16] [4.11915330e+16] [3.58754545e+16] [4.30016088e+16] [2.99886077e+16] [1.63070200e+16] [3.32663597e+16] [5.11551035e+16] [2.95320591e+16] [2.34332215e+16] [2.02044376e+16] [3.47666027e+16] [6.87340617e+16] [3.52558124e+16] [2.71348846e+16] [3.64951201e+16] [4.18601431e+16] [1.95684650e+16] [1.38936092e+16] [3.02006457e+16] [1.96173860e+16]]
Iteration 5 [[-1.58506940e+20] [-1.20537683e+20] [-1.80806345e+20] [-1.06675782e+20] [-2.26007951e+20] [-1.49542086e+20] [-1.15565519e+20] [-1.07504585e+20] [-1.03963801e+20] [-1.12552086e+20] [-1.46151974e+20] [-1.50672014e+20] [-1.42385073e+20] [-3.37354413e+20] [-9.55261885e+19] [-1.73272871e+20] [-9.94435428e+19] [-9.31154420e+19] [-1.96551642e+20] [-2.28343362e+20] [-1.33118767e+20] [-1.42234293e+20] [-1.20839027e+20] [-1.47809362e+20] [-2.93056729e+20] [-8.28697695e+19] [-1.09839996e+20] [-1.90298660e+20] [-1.65739180e+20] [-1.98660937e+20] [-1.38542837e+20] [-7.53359691e+19] [-1.53685556e+20] [-2.36328850e+20] [-1.36433652e+20] [-1.08257943e+20] [-9.33414495e+19] [-1.60616452e+20] [-3.17540981e+20] [-1.62876527e+20] [-1.25359067e+20] [-1.68601941e+20] [-1.93387537e+20] [-9.04033523e+19] [-6.41863754e+19] [-1.39522421e+20] [-9.06293597e+19]]
Iteration 83 [[-1.09904300e+306] [-8.35774743e+305] [-1.25366087e+306] [-7.39660179e+305] [-1.56707622e+306] [-1.03688320e+306] [-8.01299137e+305] [-7.45406868e+305] [-7.20856058e+305] [-7.80404831e+305] [-1.01337710e+306] [-1.04471781e+306] [-9.87258464e+305] [-2.33912159e+306] [-6.62352000e+305] [-1.20142586e+306] [-6.89513844e+305] [-6.4563655e+305] [-1.36283437e+306] [-1.58326931e+306] [-9.23008472e+305] [-9.86212994e+305] [-8.37864174e+305] [-1.02486897e+306] [-2.03197378e+306] [-5.74595914e+305] [-7.61599955e+305] [-1.31947793e+306] [-1.14918934e+306] [-1.37745963e+306] [-9.60617469e+305] [-5.22358639e+305] [-1.06561287e+306] [-1.63863846e+306] [-9.45992963e+305] [-7.50630445e+305] [-6.47203628e+305] [-1.11366977e+306] [-2.20174077e+306] [-1.12934050e+306] [-8.69204879e+305] [-1.16903893e+306] [-1.34089535e+306] [-6.26831680e+305] [-4.45050460e+305] [-9.67409627e+305] [-6.28398753e+305]]
Iteration 84 [[inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf]]
```
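The printed values show the pattern directly. From iteration 2 onward, every entry flips sign and grows by a factor of about 4.6 × 10³ per step (for example 1.60749981e+09 → -7.42664624e+12 → 3.43099835e+16).

For the squared-error cost, the error vector evolves as e ← (I - (α/m) X Xᵀ) e. Along the direction of the largest eigenvalue λ_max of XᵀX/m, it is therefore multiplied by 1 - αλ_max each step, and fixed-step gradient descent converges only if 0 < α < 2/λ_max. A factor of about -4620 with the code's `alpha = 0.001` suggests λ_max ≈ 4.6 × 10⁶. That would put the stable learning rate below roughly 4 × 10⁻⁷ for these unscaled features.

The sketch below computes that bound for a given design matrix. `max_stable_alpha` is a hypothetical helper, and the data in the demo is synthetic, not the question's file.

```python
import numpy as np


def max_stable_alpha(X_with_bias):
    """Upper bound 2 / lambda_max(X^T X / m) on a fixed learning rate for
    batch gradient descent on the squared-error cost (hypothetical helper)."""
    m = X_with_bias.shape[0]
    hessian = (X_with_bias.T @ X_with_bias) / m   # Hessian of J(theta), symmetric
    lambda_max = np.linalg.eigvalsh(hessian)[-1]  # eigvalsh returns ascending order
    return 2.0 / lambda_max


# Illustrative comparison on synthetic data with one feature in the thousands.
rng = np.random.default_rng(2)
sizes = rng.uniform(800, 4500, 47)
bedrooms = rng.integers(1, 6, 47)
X_raw = np.column_stack([np.ones(47), sizes, bedrooms])
X_scaled = np.column_stack([np.ones(47),
                            (sizes - sizes.mean()) / sizes.std(),
                            (bedrooms - bedrooms.mean()) / bedrooms.std()])
print(max_stable_alpha(X_raw))     # far below 0.001 for unscaled features
print(max_stable_alpha(X_scaled))  # on the order of 1 once features are standardized
```

The bound depends only on the design matrix, so it can be checked before running gradient descent at all.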