1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
def generate_data(n, p=0.8, *, plot=True):
    """Generate a shuffled two-class 2-D toy dataset split into train and test.

    Class 1 is a mixture of three Gaussian blobs (about 50% / 25% / 25% of
    its ``n`` points); class 2 is a single correlated Gaussian with ``n``
    points. Labels are one-hot rows: ``[1., 0.]`` for class 1 and
    ``[0., 1.]`` for class 2. Samples are drawn from NumPy's global random
    state, so call ``np.random.seed`` beforehand for reproducible output.

    Args:
        n: Number of points per class (integer); the dataset has ``2 * n``
            rows in total.
        p: Fraction used for training. The training split holds
            ``int(n * p) * 2`` rows; the rest go to the test split.
        plot: When true (the default), scatter class 1 in red and class 2 in
            blue on the current matplotlib axes via the module-level ``plt``.
            Pass ``False`` to skip plotting entirely.

    Returns:
        Tuple ``(x_train, y_train, x_test, y_test)`` of arrays, each with
        two columns.
    """
    # The third blob takes whatever is left so the three sizes always sum to
    # n. Truncating each share independently (e.g. n=10 -> 5 + 2 + 2 = 9)
    # leaves x shorter than y and makes the shuffled indexing below fail.
    n11 = int(0.5 * n)
    n12 = int(0.25 * n)
    n13 = n - n11 - n12

    x11 = np.random.multivariate_normal([4., 3.], [[4., 0.], [0., 1.]], n11)
    x12 = np.random.multivariate_normal([2., -2.], [[1., 0.], [0., 2.]], n12)
    x13 = np.random.multivariate_normal([7., -4.], [[1., 0.], [0., 1.]], n13)
    x1 = np.vstack((x11, x12, x13))

    x2 = np.random.multivariate_normal(
        [6., 0.], [[1.5, 0.5], [0.5, 1.5]], n
    )

    if plot:
        plt.scatter(x1.T[0], x1.T[1], color="red")
        plt.scatter(x2.T[0], x2.T[1], color="blue")

    x = np.vstack((x1, x2))
    # First n rows are [1, 0] (class 1), next n rows are [0, 1] (class 2),
    # matching the row order of x.
    y = np.repeat([[1., 0.], [0., 1.]], n, axis=0)

    # Apply one shared permutation so samples and labels stay aligned.
    shuffle_idx = np.arange(0, n * 2)
    np.random.shuffle(shuffle_idx)
    x_shuffled = x[shuffle_idx]
    y_shuffled = y[shuffle_idx]

    split = int(n * p) * 2
    _x_train = x_shuffled[:split]
    _y_train = y_shuffled[:split]
    _x_test = x_shuffled[split:]
    _y_test = y_shuffled[split:]
    return _x_train, _y_train, _x_test, _y_test
|