Skip to content

Commit

Permalink
WL final and all iterations option; some code refactoring; add TODOs
Browse files Browse the repository at this point in the history
  • Loading branch information
DillonZChen committed Aug 19, 2023
1 parent 41c34de commit 1e772cd
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 28 deletions.
62 changes: 37 additions & 25 deletions learner/kernels/wl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,58 @@


class WeisfeilerLehmanKernel(Kernel):
def __init__(self, iterations: int) -> None:
def __init__(self, iterations: int, all_colours: bool) -> None:
super().__init__()

# hashes neighbour multisets of colours; also acts as colour to explicit feature index
# hashes neighbour multisets of colours; same as self._representation if all_colours
self._hash = {}

# option for returning only final WL iteration
self._representation = {}

# number of wl iterations
self.iterations = iterations

# collect colours from all iterations or only final
self.all_colours = all_colours

def _get_hash_value(self, colour) -> int:
    """Return the compressed integer index for *colour*.

    Unseen colours are assigned the next free index (current size of the
    hash map), so indices are dense and assigned in first-seen order.
    """
    try:
        return self._hash[colour]
    except KeyError:
        # first time we see this colour: give it the next dense index
        idx = len(self._hash)
        self._hash[colour] = idx
        return idx

def read_train_data(self, graphs: CGraph) -> None:
""" Read data and precompute the hash function """

t = time.time()
self._train_data_colours = {}

# initial run to compute colours and hashmap
# compute colours and hashmap from training data
for G in graphs:
cur_colours = {}
histogram = {}

def store_colour(colour):
nonlocal histogram
if colour not in self._representation:
self._representation[colour] = len(self._representation)
if colour not in histogram:
histogram[colour] = 0
histogram[colour] += 1

# collect initial colours
for u in G.nodes:

# initial colour is feature of the node
colour = G.nodes[u]["colour"]
cur_colours[u] = self._get_hash_value(colour)

# check if colour in hash to compress
if colour not in self._hash:
self._hash[colour] = len(self._hash)
cur_colours[u] = self._hash[colour]

# store histogram throughout all iterations
if colour not in histogram:
histogram[colour] = 0
histogram[colour] += 1
# store histogram for all iterations or only last
if self.all_colours or self.iterations == 0:
store_colour(colour)

# WL iterations
for _ in range(self.iterations):
for itr in range(self.iterations):
new_colours = {}
for u in G.nodes:

Expand All @@ -52,21 +66,19 @@ def read_train_data(self, graphs: CGraph) -> None:
neighbour_colours.append((colour_node, colour_edge))
neighbour_colours = sorted(neighbour_colours)
colour = tuple([cur_colours[u]] + neighbour_colours)
new_colours[u] = self._get_hash_value(colour)

# check if colour in hash to compress
if colour not in self._hash:
self._hash[colour] = len(self._hash)
new_colours[u] = self._hash[colour]

# store histogram throughout all iterations
if colour not in histogram:
histogram[colour] = 0
histogram[colour] += 1
# store histogram for all iterations or only last
if self.all_colours or itr == self.iterations - 1:
store_colour(colour)
cur_colours = new_colours

# store histogram of graph colours over *all* iterations
# store histogram of graph colours
self._train_data_colours[G] = histogram

if self.all_colours:
self._representation = self._hash

t = time.time() - t
print(f"Initialised WL for {len(graphs)} graphs in {t:.2f}s")
print(f"Collected {len(self._hash)} colours over {sum(len(G.nodes) for G in graphs)} nodes")
Expand All @@ -77,12 +89,12 @@ def get_x(self, graphs: CGraph) -> np.array:
O(nd) time; n x d output
"""
n = len(graphs)
d = len(self._hash)
d = len(self._representation)
X = np.zeros((n, d))
for i, G in enumerate(graphs):
histogram = self._train_data_colours[G]
for colour in histogram:
j = self._hash[colour]
j = self._representation[colour]
X[i][j] = histogram[colour]
return X

Expand Down
2 changes: 2 additions & 0 deletions learner/representation/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ def convert_to_coloured_graph(self) -> None:
efficiently for each graph representation separately but takes more effort.
"""

# TODO optimise by converting node string names into ints and storing the map

colours = set()

c_graph = self._create_graph()
Expand Down
14 changes: 11 additions & 3 deletions learner/train_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from util.metrics import f1_macro
from sklearn.svm import LinearSVR, SVR
from sklearn.model_selection import cross_validate
from sklearn.metrics import make_scorer, f1_score, mean_squared_error
from sklearn.metrics import make_scorer, mean_squared_error


_MODELS = [
Expand All @@ -19,17 +19,21 @@
]

_CV_FOLDS = 5
_MAX_MODEL_ITER = 1000
_MAX_MODEL_ITER = 10000

def create_parser():
parser = argparse.ArgumentParser()

parser.add_argument('-r', '--rep', type=str, required=True, choices=representation.REPRESENTATIONS,
help="graph representation to use")
# TODO implement CGraph for SLG

parser.add_argument('-k', '--kernel', type=str, required=True, choices=kernels.KERNELS,
help="graph kernel to use")
parser.add_argument('-l', '--iterations', type=int, default=5,
help="number of iterations for kernel algorithms")
parser.add_argument('--final-only', dest="all_colours", action="store_false",
help="collect colours from only the final iteration of WL kernels")

parser.add_argument('-m', '--model', type=str, default="linear-svr", choices=_MODELS,
help="ML model")
Expand Down Expand Up @@ -58,8 +62,12 @@ def create_parser():

np.random.seed(args.seed)

print(f"Initialising {args.kernel}...")
graphs, y = get_dataset_from_args_kernels(args)
kernel = kernels.KERNELS[args.kernel](args.iterations)
kernel = kernels.KERNELS[args.kernel](
iterations=args.iterations,
all_colours=args.all_colours,
)
kernel.read_train_data(graphs)

print(f"Setting up training data and initialising model...")
Expand Down

0 comments on commit 1e772cd

Please sign in to comment.