Skip to content

Commit

Permalink
Conditionally explode loop when loop bound <= 32
Browse files Browse the repository at this point in the history
  • Loading branch information
DSouzaM committed Mar 8, 2024
1 parent 1d8eeaa commit f29a28b
Showing 1 changed file with 166 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@
@ShortCircuitOperation(name = "PrimitiveBoolOr", booleanConverter = PBytecodeDSLRootNode.BooleanIdentity.class, operator = Operator.OR_RETURN_CONVERTED)
@SuppressWarnings("unused")
public abstract class PBytecodeDSLRootNode extends PRootNode implements BytecodeRootNode {
private static final int EXPLODE_LOOP_THRESHOLD = 32;

@Child protected transient PythonObjectFactory factory = PythonObjectFactory.create();
@Child private transient CalleeContext calleeContext = CalleeContext.create();
Expand Down Expand Up @@ -1474,7 +1475,6 @@ public static PList perform(Object[] elements,
@Operation
public static final class MakeSet {
@Specialization
@ExplodeLoop
public static PSet perform(VirtualFrame frame, Object[] elements,
@Bind("$root") PBytecodeDSLRootNode rootNode,
@Bind("this") Node node,
Expand All @@ -1485,11 +1485,26 @@ public static PSet perform(VirtualFrame frame, Object[] elements,
assert elements.length == length;

PSet set = rootNode.factory.createSet();
if (length <= EXPLODE_LOOP_THRESHOLD) {
doExploded(frame, set, elements, length, addNode, setItemNode);
} else {
doRegular(frame, set, elements, length, addNode, setItemNode);
}
return set;
}

@ExplodeLoop
private static void doExploded(VirtualFrame frame, PSet set, Object[] elements, int length, SetNodes.AddNode addNode, HashingCollectionNodes.SetItemNode setItemNode) {
CompilerAsserts.partialEvaluationConstant(length);
for (int i = 0; i < length; i++) {
SetNodes.AddNode.add(frame, set, elements[i], addNode, setItemNode);
}
}

return set;
private static void doRegular(VirtualFrame frame, PSet set, Object[] elements, int length, SetNodes.AddNode addNode, HashingCollectionNodes.SetItemNode setItemNode) {
for (int i = 0; i < length; i++) {
SetNodes.AddNode.add(frame, set, elements[i], addNode, setItemNode);
}
}
}

Expand All @@ -1505,22 +1520,43 @@ public static PSet perform(VirtualFrame frame,
@Operation
public static final class MakeFrozenSet {
@Specialization
@ExplodeLoop
public static PFrozenSet doFrozenSet(VirtualFrame frame, @Variadic Object[] elements,
public static PFrozenSet perform(VirtualFrame frame, @Variadic Object[] elements,
@Cached(value = "elements.length", neverDefault = false) int length,
@Cached HashingStorageSetItem hashingStorageLibrary,
@Bind("$root") PBytecodeDSLRootNode rootNode,
@Bind("this") Node inliningTarget) {
// TODO (GR-52217): make length a DSL constant.
assert elements.length == length;

HashingStorage setStorage;
if (length <= EXPLODE_LOOP_THRESHOLD) {
setStorage = doExploded(frame, inliningTarget, elements, length, hashingStorageLibrary);
} else {
setStorage = doRegular(frame, inliningTarget, elements, length, hashingStorageLibrary);
}
return rootNode.factory.createFrozenSet(setStorage);
}

@ExplodeLoop
private static HashingStorage doExploded(VirtualFrame frame, Node inliningTarget, Object[] elements, int length, HashingStorageSetItem hashingStorageLibrary) {
CompilerAsserts.partialEvaluationConstant(length);
HashingStorage setStorage = EmptyStorage.INSTANCE;
for (int i = 0; i < length; ++i) {
Object o = elements[i];
setStorage = hashingStorageLibrary.execute(frame, inliningTarget, setStorage, o, PNone.NONE);
}
return rootNode.factory.createFrozenSet(setStorage);
return setStorage;
}

private static HashingStorage doRegular(VirtualFrame frame, Node inliningTarget, Object[] elements, int length, HashingStorageSetItem hashingStorageLibrary) {
HashingStorage setStorage = EmptyStorage.INSTANCE;
for (int i = 0; i < length; ++i) {
Object o = elements[i];
setStorage = hashingStorageLibrary.execute(frame, inliningTarget, setStorage, o, PNone.NONE);
}
return setStorage;
}

}

@Operation
Expand All @@ -1534,13 +1570,24 @@ public static PList perform(@Bind("$root") PBytecodeDSLRootNode rootNode) {
@Operation
public static final class MakeTuple {
@Specialization
@ExplodeLoop
public static Object perform(@Variadic Object[] elements,
@Cached(value = "elements.length", neverDefault = false) int length,
@Bind("$root") PBytecodeDSLRootNode rootNode) {
// TODO (GR-52217): make length a DSL constant.
assert elements.length == length;

Object[] elts;
if (length <= EXPLODE_LOOP_THRESHOLD) {
elts = doExploded(elements, length);
} else {
elts = doRegular(elements, length);
}
return rootNode.factory.createTuple(elts);
}

@ExplodeLoop
private static Object[] doExploded(Object[] elements, int length) {
CompilerAsserts.partialEvaluationConstant(length);
int totalLength = 0;
for (int i = 0; i < length; i++) {
totalLength += ((Object[]) elements[i]).length;
Expand All @@ -1554,8 +1601,24 @@ public static Object perform(@Variadic Object[] elements,
System.arraycopy(arr, 0, elts, idx, len);
idx += len;
}
return elts;
}

return rootNode.factory.createTuple(elts);
private static Object[] doRegular(Object[] elements, int length) {
int totalLength = 0;
for (int i = 0; i < length; i++) {
totalLength += ((Object[]) elements[i]).length;
}

Object[] elts = new Object[totalLength];
int idx = 0;
for (int i = 0; i < length; i++) {
Object[] arr = (Object[]) elements[i];
int len = arr.length;
System.arraycopy(arr, 0, elts, idx, len);
idx += len;
}
return elts;
}
}

Expand Down Expand Up @@ -1697,13 +1760,31 @@ public static Object doGeneric(Object start, Object end, Object step,
@Operation
public static final class MakeKeywords {
@Specialization
@ExplodeLoop
public static PKeyword[] perform(@Variadic Object[] keysAndValues,
@Cached(value = "keysAndValues.length", neverDefault = true) int length) {
// TODO (GR-52217): make length a DSL constant.
assert keysAndValues.length == length;
assert length % 2 == 0;

if (length <= EXPLODE_LOOP_THRESHOLD) {
return doExploded(keysAndValues, length);
} else {
return doRegular(keysAndValues, length);
}
}

@ExplodeLoop
private static PKeyword[] doExploded(Object[] keysAndValues, int length) {
CompilerAsserts.partialEvaluationConstant(length);
PKeyword[] result = new PKeyword[length / 2];
for (int i = 0; i < length; i += 2) {
CompilerAsserts.compilationConstant(keysAndValues[i]);
result[i / 2] = new PKeyword((TruffleString) keysAndValues[i], keysAndValues[i + 1]);
}
return result;
}

private static PKeyword[] doRegular(Object[] keysAndValues, int length) {
PKeyword[] result = new PKeyword[length / 2];
for (int i = 0; i < length; i += 2) {
CompilerAsserts.compilationConstant(keysAndValues[i]);
Expand All @@ -1727,7 +1808,6 @@ public static PKeyword[] perform(Object sourceCollection,
@Operation
public static final class MakeDict {
@Specialization
@ExplodeLoop
public static PDict perform(VirtualFrame frame, @Variadic Object[] keysAndValues,
@Bind("$root") PBytecodeDSLRootNode rootNode,
@Cached(value = "keysAndValues.length", neverDefault = true) int length,
Expand All @@ -1736,6 +1816,17 @@ public static PDict perform(VirtualFrame frame, @Variadic Object[] keysAndValues
assert keysAndValues.length == length;

PDict dict = rootNode.factory.createDict();
if (length <= EXPLODE_LOOP_THRESHOLD) {
doExploded(frame, keysAndValues, length, updateNode, dict);
} else {
doRegular(frame, keysAndValues, length, updateNode, dict);
}
return dict;
}

@ExplodeLoop
private static void doExploded(VirtualFrame frame, Object[] keysAndValues, int length, DictNodes.UpdateNode updateNode, PDict dict) {
CompilerAsserts.partialEvaluationConstant(length);
for (int i = 0; i < length; i += 2) {
Object key = keysAndValues[i];
Object value = keysAndValues[i + 1];
Expand All @@ -1745,8 +1836,18 @@ public static PDict perform(VirtualFrame frame, @Variadic Object[] keysAndValues
dict.setItem(key, value);
}
}
}

return dict;
private static void doRegular(VirtualFrame frame, Object[] keysAndValues, int length, DictNodes.UpdateNode updateNode, PDict dict) {
for (int i = 0; i < length; i += 2) {
Object key = keysAndValues[i];
Object value = keysAndValues[i + 1];
if (key == PNone.NO_VALUE) {
updateNode.execute(frame, dict, value);
} else {
dict.setItem(key, value);
}
}
}
}

Expand Down Expand Up @@ -2569,18 +2670,27 @@ public static PCell[] doLoadClosure(VirtualFrame frame) {

@Operation
public static final class StoreRange {
@Specialization(guards = {"locals.length <= 32"})
@Specialization
public static void perform(VirtualFrame frame, Object[] values,
LocalSetterRange locals) {
if (values.length <= EXPLODE_LOOP_THRESHOLD) {
doExploded(frame, values, locals);
} else {
doRegular(frame, values, locals);
}
}

@ExplodeLoop
public static void doExploded(VirtualFrame frame, Object[] values,
private static void doExploded(VirtualFrame frame, Object[] values,
LocalSetterRange locals) {
CompilerAsserts.partialEvaluationConstant(values.length);
assert values.length == locals.getLength();
for (int i = 0; i < locals.length; i++) {
locals.setObject(frame, i, values[i]);
}
}

@Specialization(replaces = "doExploded")
public static void doRegular(VirtualFrame frame, Object[] values,
private static void doRegular(VirtualFrame frame, Object[] values,
LocalSetterRange locals) {
assert values.length == locals.getLength();
for (int i = 0; i < locals.length; i++) {
Expand All @@ -2603,12 +2713,21 @@ public static PCell[] doMakeCellArray(@Variadic Object[] cells) {
@Operation
public static final class Unstar {
@Specialization
@ExplodeLoop
public static Object[] doUnstar(@Variadic Object[] values,
public static Object[] perform(@Variadic Object[] values,
@Cached(value = "values.length", neverDefault = false) int length) {
// TODO (GR-52217): make length a DSL constant.
assert values.length == length;

if (length <= EXPLODE_LOOP_THRESHOLD) {
return doExploded(values, length);
} else {
return doRegular(values, length);
}
}

@ExplodeLoop
private static Object[] doExploded(Object[] values, int length) {
CompilerAsserts.partialEvaluationConstant(length);
int totalLength = 0;
for (int i = 0; i < length; i++) {
totalLength += ((Object[]) values[i]).length;
Expand All @@ -2620,7 +2739,21 @@ public static Object[] doUnstar(@Variadic Object[] values,
System.arraycopy(values[i], 0, result, idx, nl);
idx += nl;
}
return result;
}

private static Object[] doRegular(Object[] values, int length) {
int totalLength = 0;
for (int i = 0; i < length; i++) {
totalLength += ((Object[]) values[i]).length;
}
Object[] result = new Object[totalLength];
int idx = 0;
for (int i = 0; i < length; i++) {
int nl = ((Object[]) values[i]).length;
System.arraycopy(values[i], 0, result, idx, nl);
idx += nl;
}
return result;
}
}
Expand Down Expand Up @@ -3112,8 +3245,7 @@ public static void doExceptional(VirtualFrame frame,
@Operation
public static final class BuildString {
@Specialization
@ExplodeLoop
public static Object doBuildString(@Variadic Object[] strings,
public static Object perform(@Variadic Object[] strings,
@Cached(value = "strings.length", neverDefault = false) int length,
@Cached TruffleStringBuilder.AppendStringNode appendNode,
@Cached TruffleStringBuilder.ToStringNode toString) {
Expand All @@ -3122,11 +3254,26 @@ public static Object doBuildString(@Variadic Object[] strings,

TruffleStringBuilder tsb = TruffleStringBuilder.create(PythonUtils.TS_ENCODING);

if (length <= EXPLODE_LOOP_THRESHOLD) {
doExploded(strings, length, appendNode, tsb);
} else {
doRegular(strings, length, appendNode, tsb);
}
return toString.execute(tsb);
}

@ExplodeLoop
private static void doExploded(Object[] strings, int length, TruffleStringBuilder.AppendStringNode appendNode, TruffleStringBuilder tsb) {
CompilerAsserts.partialEvaluationConstant(length);
for (int i = 0; i < length; i++) {
appendNode.execute(tsb, (TruffleString) strings[i]);
}
}

return toString.execute(tsb);
private static void doRegular(Object[] strings, int length, TruffleStringBuilder.AppendStringNode appendNode, TruffleStringBuilder tsb) {
for (int i = 0; i < length; i++) {
appendNode.execute(tsb, (TruffleString) strings[i]);
}
}
}

Expand Down

0 comments on commit f29a28b

Please sign in to comment.