Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to #[classmethod] for SpendBundle::py_aggregate #678

Merged
merged 33 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
a4740e8
switch to classmethod for py_aggregate
matt-o-how Aug 26, 2024
d1f02d3
import under cfg feature
matt-o-how Aug 26, 2024
6a7ee09
update stubs
matt-o-how Aug 26, 2024
5985332
add test for derived class returning correctly
matt-o-how Aug 26, 2024
8644b38
change pystreamable macro to use classmethods instead of staticmethods
matt-o-how Aug 27, 2024
7c85ba1
fmt and clippy
matt-o-how Aug 27, 2024
1ab56bf
fix line ending format
matt-o-how Aug 27, 2024
c602f5c
black test
matt-o-how Aug 27, 2024
6c2a324
add tests for the streamable macro functions
matt-o-how Aug 29, 2024
7185a6b
fix tests to actually test what we're expecting
matt-o-how Aug 29, 2024
cfe8665
fix aggregate
matt-o-how Aug 30, 2024
eaa227b
nonworking commit for arvid
matt-o-how Sep 3, 2024
91205ae
fixup
arvidn Sep 3, 2024
5f0bab7
add downcasting step to streamable classmethods that support it
matt-o-how Sep 5, 2024
c256de6
fix tests and remove duplicate imports
matt-o-how Sep 5, 2024
b058c2d
add from_parent for OwnedSpendConditions and OwnedSpendBundleConditions
matt-o-how Sep 6, 2024
2723e1f
fmt
matt-o-how Sep 6, 2024
34ab0f9
fix stubs and use ?
matt-o-how Sep 9, 2024
58bf3ed
pushing broken optional skip to work from laptop
matt-o-how Sep 10, 2024
26bed89
fix
matt-o-how Sep 10, 2024
773be19
update all remaining streamable macros to use check and skip
matt-o-how Sep 10, 2024
6ba0981
update stubs to reflect new Streamable
matt-o-how Sep 10, 2024
dbef1e9
use py as paramter instead of calling with_gil()
matt-o-how Sep 10, 2024
71933f9
Add NotImplemented error for unsupported from_parent() calls
matt-o-how Sep 10, 2024
7f82ad3
re-enable from_parent in SpendBundle
matt-o-how Sep 11, 2024
89888cf
make error messages struct specific
matt-o-how Sep 12, 2024
763c1a7
fmt
matt-o-how Sep 13, 2024
423d41d
fix if statement for from_parent skip
matt-o-how Sep 13, 2024
b62d9d4
clippy fixes
matt-o-how Sep 13, 2024
ef7ecbe
add from_parent skip to aggregate()
matt-o-how Sep 13, 2024
f316c7e
clippy fix
matt-o-how Sep 13, 2024
2448708
Remove final with_gil()
matt-o-how Sep 16, 2024
2ebcd94
fmt
matt-o-how Sep 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions crates/chia-bls/src/public_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,8 @@ mod pytests {
let pk = sk.public_key();
Python::with_gil(|py| {
let string = pk.to_json_dict(py).expect("to_json_dict");
let pk2 = PublicKey::from_json_dict(string.bind(py)).unwrap();
let py_class = py.get_type_bound::<PublicKey>();
let pk2 = PublicKey::from_json_dict(&py_class, string.bind(py)).unwrap();
assert_eq!(pk, pk2);
});
}
Expand All @@ -752,8 +753,9 @@ mod pytests {
fn test_json_dict(#[case] input: &str, #[case] msg: &str) {
pyo3::prepare_freethreaded_python();
Python::with_gil(|py| {
let err =
PublicKey::from_json_dict(input.to_string().into_py(py).bind(py)).unwrap_err();
let py_class = py.get_type_bound::<PublicKey>();
let err = PublicKey::from_json_dict(&py_class, input.to_string().into_py(py).bind(py))
.unwrap_err();
assert_eq!(err.value_bound(py).to_string(), msg.to_string());
});
}
Expand Down
8 changes: 5 additions & 3 deletions crates/chia-bls/src/secret_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,8 @@ mod pytests {
let sk = SecretKey::from_seed(&data);
Python::with_gil(|py| {
let string = sk.to_json_dict(py).expect("to_json_dict");
let sk2 = SecretKey::from_json_dict(string.bind(py)).unwrap();
let py_class = py.get_type_bound::<SecretKey>();
let sk2 = SecretKey::from_json_dict(&py_class, string.bind(py)).unwrap();
assert_eq!(sk, sk2);
assert_eq!(sk.public_key(), sk2.public_key());
});
Expand Down Expand Up @@ -588,8 +589,9 @@ mod pytests {
fn test_json_dict(#[case] input: &str, #[case] msg: &str) {
pyo3::prepare_freethreaded_python();
Python::with_gil(|py| {
let err =
SecretKey::from_json_dict(input.to_string().into_py(py).bind(py)).unwrap_err();
let py_class = py.get_type_bound::<SecretKey>();
let err = SecretKey::from_json_dict(&py_class, input.to_string().into_py(py).bind(py))
.unwrap_err();
assert_eq!(err.value_bound(py).to_string(), msg.to_string());
});
}
Expand Down
8 changes: 5 additions & 3 deletions crates/chia-bls/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1259,7 +1259,8 @@ mod pytests {
let sig = sign(&sk, msg);
Python::with_gil(|py| {
let string = sig.to_json_dict(py).expect("to_json_dict");
let sig2 = Signature::from_json_dict(string.bind(py)).unwrap();
let py_class = py.get_type_bound::<Signature>();
let sig2 = Signature::from_json_dict(&py_class, string.bind(py)).unwrap();
assert_eq!(sig, sig2);
});
}
Expand All @@ -1274,8 +1275,9 @@ mod pytests {
fn test_json_dict(#[case] input: &str, #[case] msg: &str) {
pyo3::prepare_freethreaded_python();
Python::with_gil(|py| {
let err =
Signature::from_json_dict(input.to_string().into_py(py).bind(py)).unwrap_err();
let py_class = py.get_type_bound::<Signature>();
let err = Signature::from_json_dict(&py_class, input.to_string().into_py(py).bind(py))
.unwrap_err();
assert_eq!(err.value_bound(py).to_string(), msg.to_string());
});
}
Expand Down
8 changes: 5 additions & 3 deletions crates/chia-protocol/src/spend_bundle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ use clvmr::ENABLE_FIXED_DIV;

#[cfg(feature = "py-bindings")]
use pyo3::prelude::*;
#[cfg(feature = "py-bindings")]
use pyo3::types::PyType;

#[streamable(subclass)]
pub struct SpendBundle {
Expand Down Expand Up @@ -94,10 +96,10 @@ impl SpendBundle {
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl SpendBundle {
#[staticmethod]
#[classmethod]
#[pyo3(name = "aggregate")]
fn py_aggregate(spend_bundles: Vec<SpendBundle>) -> SpendBundle {
SpendBundle::aggregate(&spend_bundles)
fn py_aggregate(_cls: &Bound<'_, PyType>, spend_bundles: Vec<Self>) -> Self {
Self::aggregate(&spend_bundles)
arvidn marked this conversation as resolved.
Show resolved Hide resolved
}

#[pyo3(name = "name")]
Expand Down
16 changes: 8 additions & 8 deletions crates/chia_py_streamable_macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ pub fn py_streamable_macro(input: proc_macro::TokenStream) -> proc_macro::TokenS
py_protocol.extend(quote! {
#[pyo3::pymethods]
impl #ident {
#[staticmethod]
#[classmethod]
#[pyo3(signature=(json_dict))]
arvidn marked this conversation as resolved.
Show resolved Hide resolved
pub fn from_json_dict(json_dict: &pyo3::Bound<pyo3::PyAny>) -> pyo3::PyResult<Self> {
pub fn from_json_dict(_cls: &pyo3::Bound<'_, pyo3::types::PyType>, json_dict: &pyo3::Bound<pyo3::PyAny>) -> pyo3::PyResult<Self> {
<Self as #crate_name::from_json_dict::FromJsonDict>::from_json_dict(json_dict)
}

Expand All @@ -174,9 +174,9 @@ pub fn py_streamable_macro(input: proc_macro::TokenStream) -> proc_macro::TokenS
let streamable = quote! {
#[pyo3::pymethods]
impl #ident {
#[staticmethod]
#[classmethod]
#[pyo3(name = "from_bytes")]
pub fn py_from_bytes(blob: pyo3::buffer::PyBuffer<u8>) -> pyo3::PyResult<Self> {
pub fn py_from_bytes(_cls: &pyo3::Bound<'_, pyo3::types::PyType>, blob: pyo3::buffer::PyBuffer<u8>) -> pyo3::PyResult<Self> {
if !blob.is_c_contiguous() {
panic!("from_bytes() must be called with a contiguous buffer");
}
Expand All @@ -186,9 +186,9 @@ pub fn py_streamable_macro(input: proc_macro::TokenStream) -> proc_macro::TokenS
<Self as #crate_name::Streamable>::from_bytes(slice).map_err(|e| <#crate_name::chia_error::Error as Into<pyo3::PyErr>>::into(e))
}

#[staticmethod]
#[classmethod]
#[pyo3(name = "from_bytes_unchecked")]
pub fn py_from_bytes_unchecked(blob: pyo3::buffer::PyBuffer<u8>) -> pyo3::PyResult<Self> {
pub fn py_from_bytes_unchecked(_cls: &pyo3::Bound<'_, pyo3::types::PyType>, blob: pyo3::buffer::PyBuffer<u8>) -> pyo3::PyResult<Self> {
if !blob.is_c_contiguous() {
panic!("from_bytes_unchecked() must be called with a contiguous buffer");
}
Expand All @@ -199,9 +199,9 @@ pub fn py_streamable_macro(input: proc_macro::TokenStream) -> proc_macro::TokenS
}

// returns the type as well as the number of bytes read from the buffer
#[staticmethod]
#[classmethod]
#[pyo3(signature= (blob, trusted=false))]
pub fn parse_rust<'p>(blob: pyo3::buffer::PyBuffer<u8>, trusted: bool) -> pyo3::PyResult<(Self, u32)> {
pub fn parse_rust<'p>(_cls: &pyo3::Bound<'_, pyo3::types::PyType>, blob: pyo3::buffer::PyBuffer<u8>, trusted: bool) -> pyo3::PyResult<(Self, u32)> {
arvidn marked this conversation as resolved.
Show resolved Hide resolved
if !blob.is_c_contiguous() {
panic!("parse_rust() must be called with a contiguous buffer");
}
Expand Down
12 changes: 12 additions & 0 deletions tests/test_spend_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,15 @@ def test_spend_bundle(

rem = f"{removals}"
assert rem == expected_rem


class NewAndImprovedSpendBundle(PySpendBundle):
test_bool = True


def test_derive_class():
test = PySpendBundle.aggregate([])
assert isinstance(test, PySpendBundle)
test = NewAndImprovedSpendBundle.aggregate([])
assert isinstance(test, NewAndImprovedSpendBundle)
assert test.test_bool
18 changes: 9 additions & 9 deletions wheel/generate_type_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,19 @@ def __repr__(self) -> str: ...
def __richcmp__(self) -> Any: ...
def __deepcopy__(self) -> {name}: ...
def __copy__(self) -> {name}: ...
@staticmethod
def from_bytes(bytes) -> {name}: ...
@staticmethod
def from_bytes_unchecked(bytes) -> {name}: ...
@staticmethod
def parse_rust(ReadableBuffer, bool = False) -> Tuple[{name}, int]: ...
@classmethod
def from_bytes(cls, bytes) -> {name}: ...
arvidn marked this conversation as resolved.
Show resolved Hide resolved
@classmethod
def from_bytes_unchecked(cls, bytes) -> {name}: ...
@classmethod
def parse_rust(cls, ReadableBuffer, bool = False) -> Tuple[{name}, int]: ...
def to_bytes(self) -> bytes: ...
def __bytes__(self) -> bytes: ...
def stream_to_bytes(self) -> bytes: ...
def get_hash(self) -> bytes32: ...
def to_json_dict(self) -> Any: ...
@staticmethod
def from_json_dict(json_dict: Any) -> {name}: ...
@classmethod
def from_json_dict(cls, json_dict: Any) -> {name}: ...
"""
)

Expand Down Expand Up @@ -222,7 +222,7 @@ def parse_rust_source(filename: str, upper_case: bool) -> List[Tuple[str, List[s
"def uncurry(self) -> Tuple[ChiaProgram, ChiaProgram]: ...",
],
"SpendBundle": [
"@staticmethod\n def aggregate(sbs: List[SpendBundle]) -> SpendBundle: ...",
"@classmethod\n def aggregate(sbs: List[SpendBundle]) -> SpendBundle: ...",
arvidn marked this conversation as resolved.
Show resolved Hide resolved
"def name(self) -> bytes32: ...",
"def removals(self) -> List[Coin]: ...",
"def additions(self) -> List[Coin]: ...",
Expand Down
Loading
Loading