Skip to content

Commit

Permalink
Implement sq_repeat and nb_multiply CPython like slots
Browse files Browse the repository at this point in the history
  • Loading branch information
steve-s committed Sep 13, 2024
1 parent 7e3c3f7 commit 8b241b1
Show file tree
Hide file tree
Showing 42 changed files with 657 additions and 426 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,15 @@ enum SlotKind {
nb_bool("__bool__"),
/** foo + bar */
nb_add("__add__, __radd__"),
nb_multiply("__mul__, __rmul__"),
/** sequence length/size */
sq_length("__len__"),
/** sequence item: read element at index */
sq_item("__getitem__"),
/** seq + seq, nb_add is tried before */
sq_concat("__add__"),
/** seq * number, nb_multiply is tried before */
sq_repeat("__mul__"),
/** mapping length */
mp_length("__len__"),
/** mapping subscript, e.g. o[key], o[i:j] */
Expand Down
4 changes: 4 additions & 0 deletions graalpython/com.oracle.graal.python.cext/src/abstract.c
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,7 @@ PyNumber_Check(PyObject *o)
PyNumberMethods *nb = Py_TYPE(o)->tp_as_number;
return nb && (nb->nb_index || nb->nb_int || nb->nb_float || PyComplex_Check(o));
}
#endif // GraalPy

/* Binary operators */

Expand Down Expand Up @@ -1108,6 +1109,7 @@ ternary_op(PyObject *v,
return binary_op(v, w, NB_SLOT(op), op_name); \
}

#if 0 // GraalPy
BINARY_FUNC(PyNumber_Or, nb_or, "|")
BINARY_FUNC(PyNumber_Xor, nb_xor, "^")
BINARY_FUNC(PyNumber_And, nb_and, "&")
Expand Down Expand Up @@ -1754,6 +1756,7 @@ PySequence_Concat(PyObject *s, PyObject *o)
}
return type_error("'%.200s' object can't be concatenated", s);
}
#endif // GraalPy

PyObject *
PySequence_Repeat(PyObject *o, Py_ssize_t count)
Expand Down Expand Up @@ -1786,6 +1789,7 @@ PySequence_Repeat(PyObject *o, Py_ssize_t count)
return type_error("'%.200s' object can't be repeated", o);
}

#if 0 // GraalPy change
PyObject *
PySequence_InPlaceConcat(PyObject *s, PyObject *o)
{
Expand Down
2 changes: 1 addition & 1 deletion graalpython/com.oracle.graal.python.cext/src/capi.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,6 @@ Py_LOCAL_SYMBOL int is_builtin_type(PyTypeObject *tp);
#define JWRAPPER_METHOD 8
#define JWRAPPER_UNSUPPORTED 9
#define JWRAPPER_ALLOC 10
#define JWRAPPER_SSIZE_ARG JWRAPPER_ALLOC
#define JWRAPPER_GETATTR 11
#define JWRAPPER_SETATTR 12
#define JWRAPPER_RICHCMP 13
Expand Down Expand Up @@ -232,6 +231,7 @@ Py_LOCAL_SYMBOL int is_builtin_type(PyTypeObject *tp);
#define JWRAPPER_REPR 45
#define JWRAPPER_DESCR_DELETE 46
#define JWRAPPER_DELATTRO 47
#define JWRAPPER_SSIZE_ARG 48


static inline int get_method_flags_wrapper(int flags) {
Expand Down
1 change: 1 addition & 0 deletions graalpython/com.oracle.graal.python.cext/src/typeobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -9383,6 +9383,7 @@ static int type_ready_graalpy_slot_conv(PyTypeObject* cls) {
ADD_SLOT_CONV("__len__", sequences->sq_length, -1, JWRAPPER_LENFUNC);
ADD_SLOT_CONV("__add__", sequences->sq_concat, -2, JWRAPPER_BINARYFUNC);
ADD_SLOT_CONV("__mul__", sequences->sq_repeat, -2, JWRAPPER_SSIZE_ARG);
ADD_SLOT_CONV("__rmul__", sequences->sq_repeat, -2, JWRAPPER_SSIZE_ARG);
ADD_SLOT_CONV("__getitem__", sequences->sq_item, -2, JWRAPPER_GETITEM);
ADD_SLOT_CONV("__setitem__", sequences->sq_ass_item, -3, JWRAPPER_SETITEM);
ADD_SLOT_CONV("__delitem__", sequences->sq_ass_item, -3, JWRAPPER_DELITEM);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ private static String getSuffix(boolean isComplex) {
static String getSlotBaseClass(Slot s) {
return switch (s.value()) {
case nb_bool -> "TpSlotInquiry.TpSlotInquiryBuiltin";
case nb_add -> "TpSlotBinaryOp.TpSlotBinaryOpBuiltin";
case nb_add, nb_multiply -> "TpSlotBinaryOp.TpSlotBinaryOpBuiltin";
case sq_concat -> "TpSlotBinaryFunc.TpSlotSqConcat";
case sq_length, mp_length -> "TpSlotLen.TpSlotLenBuiltin" + getSuffix(s.isComplex());
case sq_item -> "TpSlotSizeArgFun.TpSlotSizeArgFunBuiltin";
case sq_item, sq_repeat -> "TpSlotSizeArgFun.TpSlotSizeArgFunBuiltin";
case mp_subscript -> "TpSlotBinaryFunc.TpSlotMpSubscript";
case tp_getattro -> "TpSlotGetAttr.TpSlotGetAttrBuiltin";
case tp_descr_get -> "TpSlotDescrGet.TpSlotDescrGetBuiltin" + getSuffix(s.isComplex());
Expand All @@ -68,10 +68,11 @@ static String getSlotNodeBaseClass(Slot s) {
return switch (s.value()) {
case tp_descr_get -> "com.oracle.graal.python.builtins.objects.type.slots.TpSlotDescrGet.DescrGetBuiltinNode";
case nb_bool -> "com.oracle.graal.python.builtins.objects.type.slots.TpSlotInquiry.NbBoolBuiltinNode";
case nb_add -> "com.oracle.graal.python.builtins.objects.type.slots.TpSlotBinaryOp.BinaryOpBuiltinNode";
case nb_add, nb_multiply -> "com.oracle.graal.python.builtins.objects.type.slots.TpSlotBinaryOp.BinaryOpBuiltinNode";
case sq_concat -> "com.oracle.graal.python.builtins.objects.type.slots.TpSlotBinaryFunc.SqConcatBuiltinNode";
case sq_length, mp_length -> "com.oracle.graal.python.builtins.objects.type.slots.TpSlotLen.LenBuiltinNode";
case sq_item -> "com.oracle.graal.python.builtins.objects.type.slots.TpSlotSizeArgFun.SqItemBuiltinNode";
case sq_repeat -> "com.oracle.graal.python.builtins.objects.type.slots.TpSlotSizeArgFun.SqRepeatBuiltinNode";
case mp_subscript -> "com.oracle.graal.python.builtins.objects.type.slots.TpSlotBinaryFunc.MpSubscriptBuiltinNode";
case tp_getattro -> "com.oracle.graal.python.builtins.objects.type.slots.TpSlotGetAttr.GetAttrBuiltinNode";
case tp_descr_set -> "com.oracle.graal.python.builtins.objects.type.slots.TpSlotDescrSet.DescrSetBuiltinNode";
Expand All @@ -84,7 +85,7 @@ static String getUncachedExecuteSignature(SlotKind s) {
case nb_bool -> "boolean executeUncached(Object self)";
case tp_descr_get -> "Object executeUncached(Object self, Object obj, Object type)";
case sq_length, mp_length -> "int executeUncached(Object self)";
case tp_getattro, tp_descr_set, tp_setattro, sq_item, mp_subscript, nb_add, sq_concat ->
case tp_getattro, tp_descr_set, tp_setattro, sq_item, mp_subscript, nb_add, sq_concat, sq_repeat, nb_multiply ->
throw new AssertionError("Should not reach here: should be always complex");
};
}
Expand All @@ -93,15 +94,18 @@ static boolean supportsComplex(SlotKind s) {
return switch (s) {
case nb_bool -> false;
case sq_length, mp_length, tp_getattro, tp_descr_get, tp_descr_set,
tp_setattro, sq_item, mp_subscript, nb_add, sq_concat ->
tp_setattro, sq_item, mp_subscript, nb_add, sq_concat,
sq_repeat, nb_multiply ->
true;
};
}

static boolean supportsSimple(SlotKind s) {
return switch (s) {
case nb_bool, sq_length, mp_length, tp_descr_get -> true;
case tp_getattro, tp_descr_set, tp_setattro, sq_item, mp_subscript, nb_add, sq_concat -> false;
case tp_getattro, tp_descr_set, tp_setattro, sq_item, mp_subscript,
nb_add, sq_concat, sq_repeat, nb_multiply ->
false;
};
}

Expand All @@ -110,18 +114,17 @@ static String getUncachedExecuteCall(SlotKind s) {
case nb_bool -> "executeBool(null, self)";
case sq_length, mp_length -> "executeInt(null, self)";
case tp_descr_get -> "execute(null, self, obj, type)";
case tp_getattro, tp_descr_set, tp_setattro, sq_item, mp_subscript, nb_add, sq_concat ->
case tp_getattro, tp_descr_set, tp_setattro, sq_item, mp_subscript,
nb_add, sq_concat, nb_multiply, sq_repeat ->
throw new AssertionError("Should not reach here: should be always complex");
};
}

public static String getExtraCtorArgs(TpSlotData slot) {
return switch (slot.slot().value()) {
case nb_add -> ", com.oracle.graal.python.nodes.SpecialMethodNames.J___ADD__";
case nb_bool, tp_setattro, tp_getattro,
tp_descr_set, tp_descr_get, mp_subscript,
mp_length, sq_concat, sq_item, sq_length ->
"";
case nb_multiply -> ", com.oracle.graal.python.nodes.SpecialMethodNames.J___MUL__";
default -> "";
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright (c) 2024, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* The Universal Permissive License (UPL), Version 1.0
*
* Subject to the condition set forth below, permission is hereby granted to any
* person obtaining a copy of this software, associated documentation and/or
* data (collectively the "Software"), free of charge and under any and all
* copyright rights in the Software, and any and all patent rights owned or
* freely licensable by each licensor hereunder covering either (i) the
* unmodified Software as contributed to or provided by such licensor, or (ii)
* the Larger Works (as defined below), to deal in both
*
* (a) the Software, and
*
* (b) any piece of software and/or hardware listed in the lrgrwrks.txt file if
* one is included with the Software each a "Larger Work" to which the Software
* is contributed by such licensors),
*
* without restriction, including without limitation the rights to copy, create
* derivative works of, display, perform, and distribute the Software and make,
* use, sell, offer for sale, import, export, have made, and have sold the
* Software and the Larger Work(s), and to sublicense the foregoing rights on
* either these or other terms.
*
* This license is subject to the following condition:
*
* The above copyright notice and either this complete permission notice or at a
* minimum a reference to the UPL must be included in all copies or substantial
* portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
// Generated by the slots_fuzzer.py
#include <Python.h>

PyObject *global_stash1;
PyObject *global_stash2;

int Native0_nb_bool(PyObject *self) { return 1; }
PyObject *Native0_nb_add(PyObject *self, PyObject *other) {
return Py_NewRef(self);
}
Py_ssize_t Native0_sq_length(PyObject *self) { return 1; }
PyObject *Native0_sq_concat(PyObject *self, PyObject *other) {
return PyLong_FromLong(10);
}
PyObject *Native0_sq_repeat(PyObject *self, Py_ssize_t count) {
return PyLong_FromLong(count);
}
Py_ssize_t Native0_mp_length(PyObject *self) { return 42; }
PyObject *Native0_tp_getattr(PyObject *self, char *name) {
return Py_NewRef(self);
}
PyObject *Native0_tp_getattro(PyObject *self, PyObject *name) {
return Py_NewRef(self);
}
PyObject *Native0_tp_descr_get(PyObject *self, PyObject *key, PyObject *type) {
Py_RETURN_NONE;
}
int Native0_tp_descr_set(PyObject *self, PyObject *key, PyObject *value) {
return 0;
}

PyNumberMethods Native0_tp_as_number = {
.nb_bool = &Native0_nb_bool,
.nb_add = &Native0_nb_add,
};
PySequenceMethods Native0_tp_as_sequence = {
.sq_length = &Native0_sq_length,
.sq_concat = &Native0_sq_concat,
.sq_repeat = &Native0_sq_repeat,
};
PyMappingMethods Native0_tp_as_mapping = {
.mp_length = &Native0_mp_length,
};

static PyTypeObject CustomType_Native0 = {
.ob_base = PyVarObject_HEAD_INIT(NULL, 0).tp_name = "test10.Native0",
.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
.tp_new = PyType_GenericNew,
.tp_as_number = &Native0_tp_as_number,
.tp_as_sequence = &Native0_tp_as_sequence,
.tp_as_mapping = &Native0_tp_as_mapping,
.tp_getattr = &Native0_tp_getattr,
.tp_getattro = &Native0_tp_getattro,
.tp_descr_get = &Native0_tp_descr_get,
.tp_descr_set = &Native0_tp_descr_set,

};

static PyObject *create_Native0(PyObject *module, PyObject *args) {
if (PyType_Ready(&CustomType_Native0) < 0)
return NULL;
Py_INCREF(&CustomType_Native0);
return (PyObject *)&CustomType_Native0;
}

static struct PyMethodDef test_module_methods[] = {
{"create_Native0", (PyCFunction)create_Native0, METH_VARARGS, ""},
{NULL, NULL, 0, NULL}};
static PyModuleDef test_module = {PyModuleDef_HEAD_INIT,
"fuzzer_test10",
"",
-1,
test_module_methods,
NULL,
NULL,
NULL,
NULL};

PyMODINIT_FUNC PyInit_fuzzer_test10(void) {
return PyModule_Create(&test_module);
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@

__dir__ = __file__.rpartition("/")[0]

from .test_modsupport import MySeq


def _safe_check(v, type_check):
try:
Expand Down Expand Up @@ -291,13 +293,40 @@ def __len__(self):
return 42


class SeqWithMulAdd:
def __getitem__(self, item):
return item

def __mul__(self, other):
return "mul:" + str(other)

def __add__(self, other):
return "add:" + str(other)

def __str__(self):
return "SeqWithMulAdd"

class NonSeqWithMulAdd:
def __mul__(self, other):
return "not expected!"

def __add__(self, other):
return "not expected!"


class DictSubclassWithSequenceMethods(dict):
def __getitem__(self, key):
return key

def __setitem__(self, key, value):
pass

def __add__(self, other):
return "not expected!"

def __mul__(self, other):
return "not expected!"


def _default_bin_arith_args():
return (
Expand Down Expand Up @@ -1487,8 +1516,17 @@ def _reference_delslice(args):
cmpfunc=unhandled_error_compare
)

def _reference_seq_repeat(args):
match args[0]:
case SeqWithMulAdd():
return "mul:" + str(args[1])
case NonSeqWithMulAdd() | DictSubclassWithSequenceMethods():
raise TypeError(f"{type(args[1])} object can't be repeated")
case _:
return args[0] * args[1]

test_PySequence_Repeat = CPyExtFunction(
lambda args: args[0] * args[1],
_reference_seq_repeat,
lambda: (
((1,), 0),
((1,), 1),
Expand All @@ -1500,6 +1538,9 @@ def _reference_delslice(args):
("hello", 1),
("hello", 3),
({}, 0),
(SeqWithMulAdd(), 42),
(NonSeqWithMulAdd(), 24),
(DictSubclassWithSequenceMethods(), 5),
),
resultspec="O",
argspec='On',
Expand Down Expand Up @@ -1527,8 +1568,19 @@ def _reference_delslice(args):
cmpfunc=unhandled_error_compare
)

def _reference_seq_concat(args):
match args[0]:
case SeqWithMulAdd():
if hasattr(args[1], "__getitem__"):
return "add:" + str(args[1])
raise TypeError("SeqWithMulAdd object can't be concatenated")
case NonSeqWithMulAdd() | DictSubclassWithSequenceMethods():
raise TypeError(f"{type(args[1])} object can't be concatenated")
case _:
return args[0] + args[1]

test_PySequence_Concat = CPyExtFunction(
lambda args: args[0] + args[1],
_reference_seq_concat,
lambda: (
((1,), tuple()),
((1,), list()),
Expand All @@ -1542,6 +1594,13 @@ def _reference_delslice(args):
("hello", ""),
({}, []),
([], {}),
(SeqWithMulAdd(), 1),
(SeqWithMulAdd(), SeqWithMulAdd()),
(SeqWithMulAdd(), [1,2,3]),
(NonSeqWithMulAdd(), 2),
(NonSeqWithMulAdd(), [1,2,3]),
(DictSubclassWithSequenceMethods(), (1,2,3)),
((1,2,3), DictSubclassWithSequenceMethods()),
),
resultspec="O",
argspec='OO',
Expand Down
Loading

0 comments on commit 8b241b1

Please sign in to comment.