import random
from dataclasses import dataclass
from typing import Callable, List, Tuple

# A 2-D sample point (x_1, x_2).
Point = Tuple[float, float]


def make_pts(N: int) -> List[Point]:
    """Return ``N`` points drawn uniformly from the unit square [0, 1) x [0, 1).

    Each point is a tuple ``(x_1, x_2)`` built from two successive calls to
    ``random.random()``, so output is reproducible via ``random.seed``.
    A non-positive ``N`` yields an empty list.
    """
    # Tuple elements are evaluated left to right, so x_1 is drawn before x_2,
    # keeping the random-number consumption order stable.
    return [(random.random(), random.random()) for _ in range(N)]


@dataclass
class Graph:
    """A labelled set of 2-D points."""

    N: int  # number of points requested
    X: List[Point]  # the sampled points, each an (x_1, x_2) tuple
    y: List[int]  # label (1 or 0) for each point, aligned index-wise with X


def _labelled(N: int, rule: Callable[[float, float], bool]) -> Graph:
    """Sample ``N`` points and label each 1 if ``rule(x_1, x_2)`` holds, else 0."""
    X = make_pts(N)
    y = [1 if rule(x_1, x_2) else 0 for x_1, x_2 in X]
    return Graph(N, X, y)


def simple(N: int) -> Graph:
    """Return ``N`` points labelled 1 when ``x_1 < 0.5`` (a vertical split), else 0."""
    return _labelled(N, lambda x_1, x_2: x_1 < 0.5)


def split(N: int) -> Graph:
    """Return ``N`` points labelled 1 when ``x_1 < 0.2`` or ``x_1 > 0.8``, else 0.

    The positive class occupies two vertical bands at the edges of the square.
    """
    return _labelled(N, lambda x_1, x_2: x_1 < 0.2 or x_1 > 0.8)


def xor(N: int) -> Graph:
    """Return ``N`` points labelled 1 in the top-left and bottom-right quadrants.

    A point is labelled 1 when exactly one coordinate is below 0.5 and the
    other is above it. Points with either coordinate exactly 0.5 are labelled
    0, because both comparisons are strict.
    """
    return _labelled(
        N,
        lambda x_1, x_2: (x_1 < 0.5 and x_2 > 0.5) or (x_1 > 0.5 and x_2 < 0.5),
    )