Facebook
From pussy hounter 69, 5 Years ago, written in Plain Text.
Embed
Download Paste or View Raw
Hits: 210
  1. import math
  2.  
  3. import sklearn
  4.  
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. import pandas as pd
  8. import seaborn as sns
  9. from sklearn import metrics
  10. from sklearn.datasets import load_boston
  11. from sklearn.linear_model import LinearRegression
  12. from sklearn.model_selection import train_test_split
  13.  
  14. # Zad 1
  15. print("Zad 1")
  16. boston_dataset = load_boston()
  17.  
  18. data_frame = pd.DataFrame(data=boston_dataset['data'], columns=boston_dataset['feature_names'])
  19. data_frame['MEDV'] = pd.Series(boston_dataset['target'])
  20.  
  21. print(data_frame.head(10))
  22. print(data_frame.tail(10))
  23.  
  24. # Zad 2
  25. print("Zad 2")
  26. data_frame.info()
  27. # a - 506
  28. # b - float64
  29. # c - nie
  30.  
  31. # Zad 3
  32. print("Zad 3")
  33. describe = data_frame.describe()
  34. print(describe)
  35. # a - sredni=3.593761 std=8.596783
  36. # b - max=50.000000  min=5.000000
  37. # c - 12.653063
  38.  
  39. # Zad 4
  40. print("Zad 4")
  41. sns.distplot(data_frame.MEDV)
  42. plt.show()
  43.  
  44. # Zad 5
  45. print("Zad 5")
  46. corr_matrix = data_frame.corr().round(2)
  47. sns.heatmap(corr_matrix, annot=True)
  48.  
  49. # a - RM, ZN, B
  50. # b - LSTAT
  51. # c - TAX do RAD oraz same dla siebie - 1
  52.  
  53. sns.lmplot('MEDV', 'RM', data=corr_matrix)
  54. sns.lmplot('MEDV', 'LSTAT', data=corr_matrix)
  55. plt.show()
  56.  
  57. # Zad 6
  58. print("Zad 6")
  59. x = data_frame[['RM', 'B', 'ZN']]
  60. y = data_frame[['MEDV']]
  61. X_train, X_test, Y_train, Y_test = train_test_split(x, y, test_size=0.2)
  62.  
  63. # Zad 7
  64. print("Zad 7")
  65. lin = LinearRegression()
  66. lin.fit(X_train, Y_train)
  67.  
  68. Y_pred = lin.predict(X_test)
  69. Y_pred_ = lin.predict(X_train)
  70.  
  71. plt.scatter(Y_test, Y_pred)
  72. plt.title('testowy')
  73. plt.show()
  74. plt.title('treningowy')
  75. plt.scatter(Y_train, Y_pred_)
  76. plt.show()
  77.  
  78. # Zad 8
  79. print("Zad 8")
  80. print('treningowy')
  81. print('RMSE: {}'.format(math.sqrt(metrics.mean_squared_error(Y_train, Y_pred_))))
  82. print('MAE: {}'.format(metrics.mean_absolute_error(Y_train, Y_pred_)))
  83.  
  84. print('testowy')
  85. print('RMSE: {}'.format(math.sqrt(metrics.mean_squared_error(Y_test, Y_pred))))
  86. print('MAE: {}'.format(metrics.mean_absolute_error(Y_test, Y_pred)))