# Standardise the augmented features, embed them with kernel PCA and
# partition the embedding with k-means.
train_x = aug_x
train_y = aug_y

# NOTE(review): scaling is applied to the full data set before the
# train/test split performed later in this script, so held-out rows
# contribute to the scaling statistics -- confirm this is intended.
train_x = preprocessing.scale(train_x)

# RBF kernel PCA down to 300 components; the embedding feeds the k-means
# step below.
reducer = KernelPCA(n_components=300, kernel='rbf')
X_embeded = reducer.fit_transform(train_x)
print(X_embeded.shape)

cluster_num = 1
clustering = KMeans(n_clusters=cluster_num).fit(X_embeded)

# Report how many samples landed in each cluster. A single bincount pass
# replaces one full scan of the labels per cluster.
for cluster_size in np.bincount(clustering.labels_, minlength=cluster_num):
    print(cluster_size)
# Fit one random-forest regressor per cluster and report its error on a
# 10% hold-out split as well as on the data it was trained on.
estimator_num = 200
models = []  # fitted regressors, indexed by cluster id
for i in range(cluster_num):
    # Boolean mask selecting the rows assigned to cluster i; computed once
    # and reused for both features and targets.
    in_cluster = clustering.labels_ == i
    trX, tsX, trY, tsY = train_test_split(
        train_x[in_cluster],
        train_y[in_cluster],
        test_size=0.1,
        shuffle=True,
    )
    print("Training with", trX.shape[0], "Samples\nTesting with", tsX.shape[0])

    rf = RandomForestRegressor(n_estimators=estimator_num)
    rf.fit(trX, trY)
    predicted_rf = rf.predict(tsX)
    predicted_rf_train = rf.predict(trX)
    models.append(rf)

    # Predictions and targets are flattened before scoring.
    # `mse` is defined elsewhere in this file -- presumably mean squared
    # error; verify against its definition.
    print("RF Result =", mse(predicted_rf.flatten(), tsY.flatten()))
    print("RF Train Result", mse(predicted_rf_train.flatten(), trY.flatten()))
{"html5":"htmlmixed","css":"css","javascript":"javascript","php":"php","python":"python","ruby":"ruby","lua":"text\/x-lua","bash":"text\/x-sh","go":"go","c":"text\/x-csrc","cpp":"text\/x-c++src","diff":"diff","latex":"stex","sql":"sql","xml":"xml","apl":"apl","asterisk":"asterisk","c_loadrunner":"text\/x-csrc","c_mac":"text\/x-csrc","coffeescript":"text\/x-coffeescript","csharp":"text\/x-csharp","d":"d","ecmascript":"javascript","erlang":"erlang","groovy":"text\/x-groovy","haskell":"text\/x-haskell","haxe":"text\/x-haxe","html4strict":"htmlmixed","java":"text\/x-java","java5":"text\/x-java","jquery":"javascript","mirc":"mirc","mysql":"sql","ocaml":"text\/x-ocaml","pascal":"text\/x-pascal","perl":"perl","perl6":"perl","plsql":"sql","properties":"text\/x-properties","q":"text\/x-q","scala":"scala","scheme":"text\/x-scheme","tcl":"text\/x-tcl","vb":"text\/x-vb","verilog":"text\/x-verilog","yaml":"text\/x-yaml","z80":"text\/x-z80"}