added splitter
umairjavaid committed Sep 8, 2023
1 parent 2601ef5 commit cd87ce9
Showing 1 changed file with 349 additions and 1 deletion.
350 changes: 349 additions & 1 deletion ivy/functional/frontends/sklearn/tree/splitter.py
@@ -20,6 +20,11 @@ def init_split(split_record, start_pos):


class Splitter:
"""Abstract splitter class.
Splitters are called by tree builders to find the best splits on both
sparse and dense data, one split at a time.
"""
def __init__(self, criterion, max_features, min_samples_leaf, min_weight_leaf, random_state):
self.criterion = criterion
self.random_state = random_state
@@ -112,4 +117,347 @@ def node_value(self, dest):

def node_impurity(self):
"""Return the impurity of the current node."""
return self.criterion.node_impurity()

class BestSplitter(Splitter):
"""Splitter for finding the best split on dense data."""

    def init(self, X, y, sample_weight, missing_values_in_feature_mask):
        # Mirrors scikit-learn's BestSplitter.init(); the base Splitter.init()
        # is assumed to be defined in the collapsed portion of this file.
        # (Splitter.__init__ takes the criterion and hyperparameters, so the
        # data must not be routed through it.)
        super().init(X, y, sample_weight, missing_values_in_feature_mask)
        self.partitioner = DensePartitioner(
            X, self.samples, self.feature_values, missing_values_in_feature_mask
        )

def node_split(self, impurity, split, n_constant_features):
return node_split_best(
self,
self.partitioner,
self.criterion,
impurity,
split,
n_constant_features,
)
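
    # Illustrative call sequence (a sketch, not part of this commit): a tree
    # builder would drive this splitter roughly as
    #
    #     splitter = BestSplitter(criterion, max_features, min_samples_leaf,
    #                             min_weight_leaf, random_state)
    #     splitter.init(X, y, sample_weight, missing_values_in_feature_mask)
    #     # ...after the node window [start:end) has been set up...
    #     split, n_const = [None], [0]   # mutable out-params, see node_split_best
    #     splitter.node_split(impurity, split, n_const)
    #     best = split[0]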

class DensePartitioner:
"""Partitioner specialized for dense data.
Note that this partitioner is agnostic to the splitting strategy (best vs. random).
"""

def __init__(self, X, samples, feature_values, missing_values_in_feature_mask):
self.X = X
self.samples = samples
self.feature_values = feature_values
self.missing_values_in_feature_mask = missing_values_in_feature_mask

def init_node_split(self, start, end):
"""Initialize splitter at the beginning of node_split."""
self.start = start
self.end = end
self.n_missing = 0

def sort_samples_and_feature_values(self, current_feature):
i = 0
current_end = 0
feature_values = self.feature_values
X = self.X
samples = self.samples
n_missing = 0
missing_values_in_feature_mask = self.missing_values_in_feature_mask

if missing_values_in_feature_mask is not None and missing_values_in_feature_mask[current_feature]:
i, current_end = self.start, self.end - 1
while i <= current_end:
if isnan(X[samples[current_end], current_feature]):
n_missing += 1
current_end -= 1
continue

if isnan(X[samples[i], current_feature]):
samples[i], samples[current_end] = samples[current_end], samples[i]
n_missing += 1
current_end -= 1

feature_values[i] = X[samples[i], current_feature]
i += 1
else:
for i in range(self.start, self.end):
feature_values[i] = X[samples[i], current_feature]

        # Sort the non-missing region in place with the introsort defined
        # below; numpy views are assumed so the tandem swaps write through
        # to self.feature_values and self.samples.
        n_active = self.end - self.start - n_missing
        sort(
            feature_values[self.start : self.start + n_active],
            samples[self.start : self.start + n_active],
            n_active,
        )

self.n_missing = n_missing
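
    # Worked example (illustrative): for a feature column [1.0, nan, 3.0, nan, 2.0]
    # with samples == [0, 1, 2, 3, 4], start == 0 and end == 5, the NaN scan
    # swaps the missing rows to the tail, leaving samples == [0, 4, 2, 3, 1],
    # feature_values[:3] == [1.0, 2.0, 3.0] and self.n_missing == 2.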



def sort(feature_values, samples, n):
if n == 0:
return
maxd = 2 * int(math.log(n))
introsort(feature_values, samples, n, maxd)
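
# Example (illustrative), assuming numpy arrays so the in-place swaps are
# visible to the caller:
#
#     vals = np.array([3.0, 1.0, 2.0, 0.0])
#     idx = np.array([10, 11, 12, 13])
#     sort(vals, idx, 4)
#     # vals -> [0.0, 1.0, 2.0, 3.0], idx -> [13, 11, 12, 10]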

def introsort(feature_values, samples, n, maxd):
    # Introsort with median-of-3 pivot selection and three-way partition.
    # The recursive calls below rely on numpy slices being views, so swaps
    # write through to the caller's arrays; plain lists would be copied.
    while n > 1:
        if maxd <= 0:  # max depth limit exceeded ("gone quadratic")
            heapsort(feature_values, samples, n)  # fall back to heapsort below
            return
        maxd -= 1

pivot = median3(feature_values, n)

# Three-way partition.
i = l = 0
r = n
while i < r:
if feature_values[i] < pivot:
swap(feature_values, samples, i, l)
i += 1
l += 1
elif feature_values[i] > pivot:
r -= 1
swap(feature_values, samples, i, r)
else:
i += 1

introsort(feature_values[:l], samples[:l], l, maxd)
feature_values = feature_values[r:]
samples = samples[r:]
n -= r
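
# Three-way partition invariant (illustrative): after the partition loop,
# feature_values[:l] < pivot, feature_values[l:r] == pivot and
# feature_values[r:n] > pivot, so only the outer regions are sorted further;
# e.g. [2, 1, 3, 2] with pivot 2 partitions into [1] + [2, 2] + [3].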

def heapsort(feature_values, samples, n):
# Heapify
start = (n - 2) // 2
end = n
while True:
sift_down(feature_values, samples, start, end)
if start == 0:
break
start -= 1

# Sort by shrinking the heap, putting the max element immediately after it
end = n - 1
while end > 0:
swap(feature_values, samples, 0, end)
sift_down(feature_values, samples, 0, end)
end -= 1
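
# Example (illustrative): heapsort(np.array([2.0, 0.0, 1.0]), np.array([5, 6, 7]), 3)
# rearranges both arrays in place to [0.0, 1.0, 2.0] and [6, 7, 5].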

def sift_down(feature_values, samples, start, end):
# Restore heap order in feature_values[start:end] by moving the max element to start.
root = start
while True:
child = root * 2 + 1

# Find the max of root, left child, right child
maxind = root
        # feature_values travels in tandem with samples (see swap below),
        # so positions are compared directly, not through samples[].
        if child < end and feature_values[maxind] < feature_values[child]:
            maxind = child
        if child + 1 < end and feature_values[maxind] < feature_values[child + 1]:
            maxind = child + 1

if maxind == root:
break
else:
swap(feature_values, samples, root, maxind)
root = maxind


def swap(feature_values, samples, i, j):
    # Swap positions i and j in both arrays, keeping each value paired with
    # its sample index.
    feature_values[i], feature_values[j] = feature_values[j], feature_values[i]
    samples[i], samples[j] = samples[j], samples[i]


def median3(feature_values, n):
    # Median-of-three pivot selection, after Bentley and McIlroy (1993),
    # "Engineering a sort function", Software: Practice and Experience.
    # Needs 8/3 comparisons on average.
a, b, c = feature_values[0], feature_values[n // 2], feature_values[n - 1]
if a < b:
if b < c:
return b
elif a < c:
return c
else:
return a
elif b < c:
if a < c:
return a
else:
return c
else:
return b
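
# Example (illustrative): for feature_values == [7, 3, 5] and n == 3, we get
# a == 7, b == 3, c == 5; since a >= b, b < c and a >= c, median3 returns 5.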

# Imports gathered here for the helpers above and below; Python resolves
# these names at call time, so the earlier functions can use them.
import math
import random
from copy import copy
from math import isnan

# Mirrors scikit-learn's FEATURE_THRESHOLD for flagging (near-)constant
# features; assumed not to be defined already in the collapsed part of this file.
FEATURE_THRESHOLD = 1e-7

def node_split_best(splitter, partitioner, criterion, impurity, split, n_constant_features):
    """Find the best split on node samples[start:end]; returns 0 on success."""
    start = splitter.start
    end = splitter.end
    end_non_missing = end
    n_missing = 0
    has_missing = False

    best_split = SplitRecord()
    init_split(best_split, end)
    current_split = SplitRecord()
    best_proxy_improvement = -float("inf")

    # Direct references (not copies): the constant-feature bookkeeping and the
    # sorted feature values must persist on the splitter across calls.
    features = splitter.features
    constant_features = splitter.constant_features
    n_features = splitter.n_features
    feature_values = splitter.feature_values
    max_features = splitter.max_features
    min_samples_leaf = splitter.min_samples_leaf
    min_weight_leaf = splitter.min_weight_leaf
    random_state = random.Random(splitter.rand_r_state)

    f_i = n_features
    n_visited_features = 0
    n_found_constants = 0
    n_drawn_constants = 0
    n_known_constants = n_constant_features[0]
    n_total_constants = n_known_constants

while f_i > n_total_constants and (n_visited_features < max_features or n_visited_features <= n_found_constants + n_drawn_constants):
n_visited_features += 1
        # Draw f_j among the not-yet-drawn, not-known-constant features;
        # Python's randint is inclusive on both ends, hence the - 1.
        f_j = random_state.randint(n_drawn_constants, f_i - n_found_constants - 1)

if f_j < n_known_constants:
features[n_drawn_constants], features[f_j] = features[f_j], features[n_drawn_constants]
n_drawn_constants += 1
continue

f_j += n_found_constants
current_split.feature = features[f_j]
partitioner.sort_samples_and_feature_values(current_split.feature)
n_missing = partitioner.n_missing
end_non_missing = end - n_missing

if end_non_missing == start or feature_values[end_non_missing - 1] <= feature_values[start] + FEATURE_THRESHOLD:
features[f_j], features[n_total_constants] = features[n_total_constants], features[f_j]
n_found_constants += 1
n_total_constants += 1
continue

f_i -= 1
features[f_i], features[f_j] = features[f_j], features[f_i]
has_missing = n_missing != 0

if has_missing:
criterion.init_missing(n_missing)

n_searches = 2 if has_missing else 1

        for i in range(n_searches):
            # The second pass (i == 1) sends the missing values to the left child.
            missing_go_to_left = i == 1
            criterion.missing_go_to_left = missing_go_to_left
            criterion.reset()
            p = p_prev = start

            while p < end_non_missing:
                # next_p is assumed to return the advanced (p_prev, p) pair,
                # standing in for the Cython version's in/out pointer arguments.
                p_prev, p = partitioner.next_p(p_prev, p)

                if p >= end_non_missing:
                    continue

if missing_go_to_left:
n_left = p - start + n_missing
n_right = end_non_missing - p
else:
n_left = p - start
n_right = end_non_missing - p + n_missing

if n_left < min_samples_leaf or n_right < min_samples_leaf:
continue

current_split.pos = p
criterion.update(current_split.pos)

if criterion.weighted_n_left < min_weight_leaf or criterion.weighted_n_right < min_weight_leaf:
continue

current_proxy_improvement = criterion.proxy_impurity_improvement()

if current_proxy_improvement > best_proxy_improvement:
best_proxy_improvement = current_proxy_improvement
current_split.threshold = (feature_values[p_prev] / 2.0 + feature_values[p] / 2.0)

if current_split.threshold == feature_values[p] or current_split.threshold == float("inf") or current_split.threshold == -float("inf"):
current_split.threshold = feature_values[p_prev]

current_split.n_missing = n_missing
if n_missing == 0:
current_split.missing_go_to_left = n_left > n_right
else:
current_split.missing_go_to_left = missing_go_to_left

                        # Copy rather than alias: current_split keeps mutating
                        # as later candidate positions are evaluated.
                        best_split = copy(current_split)

if has_missing:
n_left, n_right = end - start - n_missing, n_missing
p = end - n_missing
missing_go_to_left = False

if not (n_left < min_samples_leaf or n_right < min_samples_leaf):
criterion.missing_go_to_left = missing_go_to_left
criterion.update(p)

if not (criterion.weighted_n_left < min_weight_leaf or criterion.weighted_n_right < min_weight_leaf):
current_proxy_improvement = criterion.proxy_impurity_improvement()

if current_proxy_improvement > best_proxy_improvement:
best_proxy_improvement = current_proxy_improvement
current_split.threshold = float("inf")
current_split.missing_go_to_left = missing_go_to_left
current_split.n_missing = n_missing
current_split.pos = p
                            best_split = copy(current_split)  # copy, don't alias

    # Reorganize into samples[start:best_split.pos] + samples[best_split.pos:end]
    if best_split.pos < end:
        partitioner.partition_samples_final(
            best_split.pos,
            best_split.threshold,
            best_split.feature,
            best_split.n_missing,
        )

        if best_split.n_missing != 0:
            criterion.init_missing(best_split.n_missing)
        criterion.missing_go_to_left = best_split.missing_go_to_left

        criterion.reset()
        criterion.update(best_split.pos)
        # children_impurity is assumed to return the (left, right) pair,
        # standing in for the Cython version's output pointers.
        best_split.impurity_left, best_split.impurity_right = (
            criterion.children_impurity()
        )
        best_split.improvement = criterion.impurity_improvement(
            impurity,
            best_split.impurity_left,
            best_split.impurity_right,
        )
        shift_missing_values_to_left_if_required(
            best_split,
            splitter.samples,
            end,
        )

    # Respect the invariant for constant features: the original order of
    # features[:n_known_constants] must be preserved for sibling and child nodes.
    features[:n_known_constants] = constant_features[:n_known_constants]
    # Record the newly found constant features right after the known ones.
    constant_features[n_known_constants : n_known_constants + n_found_constants] = (
        features[n_known_constants : n_known_constants + n_found_constants]
    )

split[0] = best_split
n_constant_features[0] = n_total_constants
return 0
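
# Usage sketch (illustrative): node_split_best communicates through mutable
# out-parameters, mirroring the pointer arguments of the Cython original:
#
#     split, n_const = [None], [0]
#     node_split_best(splitter, partitioner, criterion, impurity, split, n_const)
#     best = split[0]       # SplitRecord with feature/threshold/pos filled in
#     n_known = n_const[0]  # constant-feature count to pass to the children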

def shift_missing_values_to_left_if_required(best, samples, end):
# The partitioner partitions the data such that the missing values are in
# samples[-n_missing:] for the criterion to consume. If the missing values
# are going to the right node, then the missing values are already in the
# correct position. If the missing values go left, then we move the missing
# values to samples[best.pos:best.pos+n_missing] and update `best.pos`.
if best.n_missing > 0 and best.missing_go_to_left:
for p in range(best.n_missing):
i = best.pos + p
current_end = end - 1 - p
samples[i], samples[current_end] = samples[current_end], samples[i]
best.pos += best.n_missing
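
# Worked example (illustrative): samples == [a, b, c, d, m1, m2] with the two
# missing rows m1, m2 at the tail, best.pos == 2 and end == 6. The two swaps
# give [a, b, m2, m1, d, c] and best.pos becomes 4, so the missing rows now
# occupy samples[2:4], the left side of the split.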
