
How to set the input arguments of fitrgp function and predict function for Gaussian Process Regression?

Hello everyone,
I have implemented Gaussian Process Regression in Python using scikit-learn's GaussianProcessRegressor, and I want to rewrite it in MATLAB.
Target function: y = (x-0.47)^2*sin(3*x)
First, I have two observations: x1 = 0, y1 = 0 and x2 = 1, y2 = 0.03964061.
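A quick check of the two observations: evaluating the target function at x = 0 gives exactly 0 (because sin(0) = 0), and at x = 1 it reproduces y2 = 0.53^2 * sin(3). The function name objective mirrors the Python code in this question.

```python
import math

def objective(x):
    """Target function y = (x - 0.47)^2 * sin(3x)."""
    return (x - 0.47) ** 2 * math.sin(3 * x)

y1 = objective(0.0)
y2 = objective(1.0)
print(y1, round(y2, 8))  # prints: 0.0 0.03964061
```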
With the following Python code I can visualize the result of Gaussian Process Regression: the mean function is roughly a line through the two observations, and the predicted variance is larger in the middle and smaller near the two observations, which is reasonable.
Python Code:
import numpy as np
from matplotlib import pyplot as plt
from sklearn.gaussian_process import GaussianProcessRegressor

# objective function (np.sin accepts scalars and arrays alike)
def objective(x):
    return (x - 0.47)**2 * np.sin(3 * x)

# default settings: no kernel or alpha specified
surrogate = GaussianProcessRegressor()

# visualization
def plot(X, y, xsamples, ysamples, yhat, std, new_x, new_y, i):
    plt.figure(figsize=(12, 6))
    plt.plot(X, y, label='real')
    plt.scatter(xsamples, ysamples, label='explored samples')
    plt.plot(X, yhat, label='gaussian process - mean', c='g')
    plt.plot(X, yhat + std, label='gaussian process - upper/lower bound', c='g', linestyle='--')
    plt.plot(X, yhat - std, c='g', linestyle='--')
    plt.scatter([new_x], [new_y], label='next sample', c='r')
    plt.legend()
    plt.title(f'Iteration {i}')
    plt.show()

# grid-based sample of the domain [0,1)
X = np.arange(0, 1, 0.01).reshape(-1, 1)
# sample the domain without noise
y = objective(X)
# location of the best (largest) value on the grid
ix = np.argmax(y)
# initial observations at x = 0 and x = 1
xsamples = np.array([[0], [1]])
ysamples = objective(xsamples)
plt.figure(figsize=(12, 6))
plt.plot(X, y)
plt.annotate('Optima', (X[ix, 0], y[ix, 0]))
plt.scatter(xsamples, ysamples)
plt.show()

# pre-chosen sequence of new sample points
new_x_vector = np.array([[0.53], [0.91], [0.82], [0.87], [0.87]])
for i in range(5):
    # fit the surrogate to all samples collected so far
    surrogate.fit(xsamples, ysamples)
    yhat, std = surrogate.predict(X, return_std=True)
    std = std.reshape(-1, 1)
    yhat = yhat.reshape(-1, 1)
    # step: evaluate the next sample point
    new_x = new_x_vector[i]
    print(new_x)
    new_y = objective(new_x)
    print(new_y)
    plot(X, y, xsamples, ysamples, yhat, std, new_x, new_y, i)
    xsamples = np.vstack((xsamples, new_x))
    ysamples = np.vstack((ysamples, new_y))
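For reference, the scikit-learn documentation states that GaussianProcessRegressor() with no kernel argument uses ConstantKernel(1.0, constant_value_bounds="fixed") * RBF(1.0, length_scale_bounds="fixed"), adds alpha=1e-10 to the diagonal, and (with the default normalize_y=False) uses a zero prior mean. Because the hyperparameters are fixed and the noise term is tiny, the posterior mean interpolates the observations and the standard deviation collapses near them. Below is a minimal NumPy sketch of that posterior, written from the standard GP equations; the helper names rbf and gp_posterior are hypothetical, not scikit-learn functions.

```python
import numpy as np

def rbf(a, b, length_scale=1.0, signal_var=1.0):
    """Squared-exponential kernel between 1-D input arrays a and b."""
    d = a[:, None] - b[None, :]
    return signal_var * np.exp(-0.5 * (d / length_scale) ** 2)

def gp_posterior(x_train, y_train, x_test, length_scale=1.0,
                 signal_var=1.0, noise_var=1e-10):
    """Zero-mean GP posterior mean and latent standard deviation."""
    K = rbf(x_train, x_train, length_scale, signal_var) + noise_var * np.eye(len(x_train))
    K_s = rbf(x_test, x_train, length_scale, signal_var)
    L = np.linalg.cholesky(K)                      # K = L L^T
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_train))
    mean = K_s @ alpha
    v = np.linalg.solve(L, K_s.T)
    var = signal_var - np.sum(v ** 2, axis=0)      # k(x,x) - k_s^T K^-1 k_s
    return mean, np.sqrt(np.maximum(var, 0.0))

x_obs = np.array([0.0, 1.0])
y_obs = (x_obs - 0.47) ** 2 * np.sin(3 * x_obs)
x_new = np.array([0.0, 0.5, 1.0])
mean, std = gp_posterior(x_obs, y_obs, x_new)
print(mean, std)
```

The standard deviation comes out near zero at x = 0 and x = 1 and about 0.17 at x = 0.5, which is the shape seen in the Python plots.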
Now I want to implement Gaussian Process Regression in MATLAB with the following code, using the fitrgp and predict functions, and visualize the result. The result is not as reasonable as in Python. First, the predicted line does not go through the two observations. Second, the predicted variance is almost the same at every point, which is not reasonable.
I think the reason is that I have not set the input arguments of fitrgp and predict correctly. Could someone help me? How should I set those input arguments to get the same or a similar result as in Python?
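One way to see how both symptoms can arise: by default fitrgp estimates a noise standard deviation (the 'Sigma' parameter) together with the kernel parameters, and name-value arguments such as 'Sigma', 'ConstantSigma', 'KernelParameters', 'BasisFunction' and 'FitMethod' control this. As I read the predict documentation, its standard deviation and interval outputs refer to a new noisy response, so they include Sigma. The self-contained zero-mean GP sketch below compares a near-zero noise term with a noise variance equal to the signal variance. The hyperparameters (length scale 1, signal std 0.02, noise std 0.02) are illustrative assumptions, not the values fitrgp actually estimated here, and the helper names are hypothetical.

```python
import numpy as np

def rbf(a, b, length_scale, signal_var):
    """Squared-exponential kernel between 1-D input arrays a and b."""
    d = a[:, None] - b[None, :]
    return signal_var * np.exp(-0.5 * (d / length_scale) ** 2)

def predictive(x_train, y_train, x_test, length_scale, signal_var, noise_var):
    """Zero-mean GP: posterior mean and std of a new noisy response."""
    K = rbf(x_train, x_train, length_scale, signal_var) + noise_var * np.eye(len(x_train))
    K_s = rbf(x_test, x_train, length_scale, signal_var)
    mean = K_s @ np.linalg.solve(K, y_train)
    # latent variance: k(x, x) - diag(K_s K^-1 K_s^T)
    latent_var = signal_var - np.sum(K_s * np.linalg.solve(K, K_s.T).T, axis=1)
    # a new observation adds the noise variance on top of the latent variance
    return mean, np.sqrt(np.maximum(latent_var, 0.0) + noise_var)

x_obs = np.array([0.0, 1.0])
y_obs = (x_obs - 0.47) ** 2 * np.sin(3 * x_obs)
x_new = np.array([0.0, 0.5, 1.0])

# illustrative hyperparameters only: length scale 1, signal std 0.02
results = {}
for label, noise_sd in [("tiny noise", 1e-5), ("large noise", 0.02)]:
    results[label] = predictive(x_obs, y_obs, x_new, 1.0, 0.02 ** 2, noise_sd ** 2)
    mean, sd = results[label]
    print(label, np.round(mean, 4), np.round(sd, 4))
```

With these made-up numbers the large-noise mean misses the observation at x = 1 by about 0.02, and its standard deviation is nearly the same (about 0.024) at x = 0, 0.5 and 1, matching the two symptoms described above.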
Matlab Code:
x_target = 0:0.01:1;
y_target = sin(3*x_target).*(x_target-0.47).^2;
figure(2);
plot(x_target,y_target,'DisplayName','Target');
hold on
x = (0:1:1)';
y = sin(3*x).*(x-0.47).^2;
gpr = fitrgp(x,y);
plot(x,y,'bo','MarkerFaceColor','b','DisplayName','Observation');
xtest = linspace(0,1,100)';
[y_predicted,y_sd,y_interval] = predict(gpr,xtest);
plot(xtest,y_predicted,'g','DisplayName','Prediction');
plot(xtest,y_interval(:,1)','g--','DisplayName','Lower');
plot(xtest,y_interval(:,2)','g--','DisplayName','Higher');
legend('show','Location','best');
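Note that the Python plot and this MATLAB plot draw different bands. The Python code draws yhat ± std (one standard deviation), whereas the third output of predict is a prediction interval whose default 'Alpha' is 0.05, i.e. a 95% interval, which for a Gaussian is about mean ± 1.96 standard deviations. A small sketch of that conversion, assuming a normal predictive distribution (gaussian_interval is a hypothetical helper):

```python
from statistics import NormalDist

def gaussian_interval(mean, sd, alpha=0.05):
    """Two-sided (1 - alpha) interval of a normal distribution N(mean, sd^2)."""
    z = NormalDist().inv_cdf(1 - alpha / 2)
    return mean - z * sd, mean + z * sd

# a 95% interval corresponds to about +/- 1.96 standard deviations
z95 = NormalDist().inv_cdf(0.975)
lo, hi = gaussian_interval(0.02, 0.01)
print(round(z95, 3), lo, hi)
```

To mirror the Python figure more closely, one could plot y_predicted + y_sd and y_predicted - y_sd instead of the columns of y_interval.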
Thanks a lot
  1 Comment
Huy on 4 Mar 2024
I'm not sure which basis function GaussianProcessRegressor() uses in Python, but fitrgp() uses a constant basis function by default. With only 2 data points, this explains the flat prediction curve. You can specify a quadratic basis function ('pureQuadratic') or a higher-order one. See the input arguments of fitrgp(): https://www.mathworks.com/help/stats/fitrgp.html
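The basis-function idea can be sketched as follows: with an explicit basis h(x), the model is y = h(x)ᵀβ + f(x) with f a zero-mean GP, and for fixed kernel hyperparameters β is taken as the generalized least squares estimate. This is my understanding of fitrgp's explicit-basis model, not its implementation. scikit-learn's GaussianProcessRegressor has no basis term; it uses a zero prior mean (or the training mean when normalize_y=True). The sketch compares no basis, a constant basis and a linear basis with a fixed unit RBF kernel; all helper names are hypothetical.

```python
import numpy as np

def rbf(a, b, length_scale=1.0, signal_var=1.0):
    """Squared-exponential kernel between 1-D input arrays a and b."""
    d = a[:, None] - b[None, :]
    return signal_var * np.exp(-0.5 * (d / length_scale) ** 2)

def basis(x, kind):
    """Basis matrix H: 'none' has no columns, 'constant' is ones, 'linear' is [1, x]."""
    if kind == "none":
        return np.zeros((len(x), 0))
    if kind == "constant":
        return np.ones((len(x), 1))
    if kind == "linear":
        return np.column_stack([np.ones(len(x)), x])
    raise ValueError(f"unknown basis: {kind}")

def gp_with_basis(x_train, y_train, x_test, kind, noise_var=1e-10):
    """Posterior mean of y = h(x)^T beta + f(x), with beta from GLS."""
    K = rbf(x_train, x_train) + noise_var * np.eye(len(x_train))
    H, H_s = basis(x_train, kind), basis(x_test, kind)
    if H.shape[1] > 0:
        Ki_H = np.linalg.solve(K, H)
        beta = np.linalg.solve(H.T @ Ki_H, H.T @ np.linalg.solve(K, y_train))
    else:
        beta = np.zeros(0)
    resid = y_train - H @ beta                      # GP models what the basis misses
    mean = H_s @ beta + rbf(x_test, x_train) @ np.linalg.solve(K, resid)
    return mean, beta

x_obs = np.array([0.0, 1.0])
y_obs = (x_obs - 0.47) ** 2 * np.sin(3 * x_obs)
x_eval = np.array([0.0, 1.0, 5.0])                  # x = 5 is far from the data
results = {k: gp_with_basis(x_obs, y_obs, x_eval, k) for k in ("none", "constant", "linear")}
for kind, (mean, beta) in results.items():
    print(kind, np.round(mean, 4), np.round(beta, 4))
```

In this sketch, with a near-zero noise term every variant still passes through the two observations; the basis mainly changes what the mean reverts to away from the data (0 with no basis, β with a constant basis, the extrapolated line with a linear basis).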


Answers (1)

Sudarsanan A K on 27 Oct 2023
Hello Shiqun,
I understand that you want to implement Gaussian Process Regression in MATLAB and get a result similar to your Python code. With the MATLAB code you provided, the predicted line does not go through the two observations and the variance is almost the same at every point. You suspect that the input arguments of the "fitrgp" and "predict" functions are set incorrectly.
Comparing your Python code with your MATLAB code, the main differences lie in how the initial observations are defined and how the target values are generated.
To define the initial observations as in your Python code and generate the target values, you can modify your MATLAB code slightly as follows:
% Objective function
objective = @(x) (x-0.47).^2 .* sin(3 * x);
% Initial observations
x1 = 0;
y1 = 0;
x2 = 1;
y2 = 0.03964061;
% Grid-based sample of the domain [0,1]
X = (0:0.01:1)';
y = objective(X); % Sample the domain without noise
% Find best result
[~, ix] = max(y);
% Initial observations
xsamples = [x1; x2];
ysamples = [y1; y2];
% Plot initial observations
figure;
plot(X, y);
hold on;
scatter(xsamples, ysamples);
title('Iteration 0');
legend('Real', 'Explored samples');
new_x_vector = [0.53; 0.91; 0.82; 0.87; 0.87];
for i = 1:5
% Fit Gaussian Process Regression
surrogate = fitrgp(xsamples, ysamples);
% Predict mean and standard deviation (named ysd so the built-in std function is not shadowed)
[yhat, ysd] = predict(surrogate, X);
% Step
new_x = new_x_vector(i);
new_y = objective(new_x);
% Plot iteration results
figure;
plot(X, y);
hold on;
scatter(xsamples, ysamples, 'b');
plot(X, yhat, 'g', 'LineWidth', 1.5);
plot(X, yhat+ysd, 'g--');
plot(X, yhat-ysd, 'g--');
scatter(new_x, new_y, 'r');
title(['Iteration ', num2str(i)]);
legend('Real', 'Explored samples', 'Gaussian Process - Mean', 'Gaussian Process - Upper/Lower Bound', '', 'Next Sample', 'Location', 'northwest');
% Update observations
xsamples = [xsamples; new_x];
ysamples = [ysamples; new_y];
end
The differences are listed below:
  • Data Points: In the modified code snippet, the initial observations are defined explicitly as "x1", "y1", "x2", and "y2". In your code snippet, the observations are computed directly as the vectors "x" and "y" and then plotted with the "plot" function.
  • Data Generation: In the modified code snippet, the anonymous function "objective" is used to generate the target values "y" for visualization. In your code snippet, the target values "y_target" are calculated by writing out the target function formula directly.
I hope this helps!
