diff --git a/binaries.json b/binaries.json index e5e7a17a7144..d74dcf6f8531 100644 --- a/binaries.json +++ b/binaries.json @@ -5,11 +5,9 @@ { "utils": [ "C", - "CC", "CD", "CI", "CL", - "CM", "CV", "CX", "D", @@ -43,7 +41,6 @@ "III", "IIL", "IIM", - "IIV", "IIX", "IL", "ILC", @@ -62,7 +59,6 @@ "IMV", "IMX", "IV", - "IVC", "IVD", "IVI", "IVL", @@ -103,7 +99,13 @@ "VCV", "VCX", "VD", + "VDC", + "VDD", "VDI", + "VDL", + "VDM", + "VDV", + "VDX", "VI", "VIC", "VID", @@ -149,4 +151,4 @@ } ] } -} +} \ No newline at end of file diff --git a/ivy/compiler/compiler.py b/ivy/compiler/compiler.py index 4cec6f1c0d97..c59dc9a8340d 100644 --- a/ivy/compiler/compiler.py +++ b/ivy/compiler/compiler.py @@ -2,15 +2,15 @@ def clear_graph_cache(): - """Clears the graph cache which gets populated if `graph_caching` is set to - `True` in `ivy.trace_graph`, `ivy.transpile` or `ivy.unify`. Use this to + """Clears the graph cache which gets populated if `graph_caching` is set + to `True` in `ivy.trace_graph`, `ivy.transpile` or `ivy.unify`. Use this to reset or clear the graph cache if needed. Examples -------- >>> import ivy - >>> ivy.clear_graph_cache() - """ + >>> ivy.clear_graph_cache()""" + from ._compiler import clear_graph_cache as _clear_graph_cache return _clear_graph_cache() @@ -55,8 +55,8 @@ def graph_transpile( Returns ------- - Either a transpiled Graph or a non-initialized LazyGraph. - """ + Either a transpiled Graph or a non-initialized LazyGraph.""" + from ._compiler import graph_transpile as _graph_transpile return _graph_transpile( @@ -96,7 +96,6 @@ def source_to_source( e.g. (source="torch_frontend", target="ivy") or (source="torch_frontend", target="tensorflow") etc. Args: - ---- object: The object (class/function) to be translated. source (str, optional): The source framework. Defaults to 'torch'. target (str, optional): The target framework. Defaults to 'tensorflow'. @@ -107,9 +106,8 @@ def source_to_source( the old implementation. Defaults to 'True'. Returns: - ------- - The translated object. - """ + The translated object.""" + from ._compiler import source_to_source as _source_to_source return _source_to_source( @@ -140,8 +138,7 @@ def trace_graph( params_v=None, v=None ): - """Takes `fn` and traces it into a more efficient composition of backend - operations. + """Takes `fn` and traces it into a more efficient composition of backend operations. Parameters ---------- @@ -211,8 +208,8 @@ def trace_graph( >>> start = time.time() >>> graph(x) >>> print(time.time() - start) - 0.0001785755157470703 - """ + 0.0001785755157470703""" + from ._compiler import trace_graph as _trace_graph return _trace_graph( @@ -252,7 +249,6 @@ def transpile( e.g. (source="torch_frontend", target="ivy") or (source="torch_frontend", target="tensorflow") etc. Args: - ---- object: The object (class/function) to be translated. source (str, optional): The source framework. Defaults to 'torch'. target (str, optional): The target framework. Defaults to 'tensorflow'. @@ -263,9 +259,8 @@ def transpile( the old implementation. Defaults to 'True'. Returns: - ------- - The translated object. - """ + The translated object.""" + from ._compiler import transpile as _transpile return _transpile( diff --git a/ivy/functional/backends/tensorflow/general.py b/ivy/functional/backends/tensorflow/general.py index d44f3490f984..eb342e66e6df 100644 --- a/ivy/functional/backends/tensorflow/general.py +++ b/ivy/functional/backends/tensorflow/general.py @@ -61,7 +61,51 @@ def get_item( ) -> Union[tf.Tensor, tf.Variable]: if ivy.is_array(query) and ivy.is_bool_dtype(query) and not len(query.shape): return tf.expand_dims(x, 0) - return x[query] + if isinstance(query, (tf.Tensor, tf.Variable)): + if query.dtype == tf.bool: + return tf.boolean_mask(x, query, axis=0) + else: + query = tf.cast(query, tf.int64) + return tf.gather(x, query, axis=0) + else: + if any([isinstance(q, slice) for q in query]): + # convert any lists/tuples within the query to slices + query = tuple([slice(*q) if isinstance(q, (list, tuple)) else q for q in query]) + # for slices and other basic indexing, use __getitem__ + return x[query] + + +def set_item( + x: Union[tf.Tensor, tf.Variable], + query: Union[tf.Tensor, tf.Variable, Tuple], + val: Union[tf.Tensor, tf.Variable], + /, + *, + copy: Optional[bool] = False, +) -> Union[tf.Tensor, tf.Variable]: + # TODO: we should re-write this at some point so it's compatible with tf.function (don't use numpy as an intermediary) + # when doing this, be sure to check the performance of the function on large tensors, compared to this implementation + + if tf.is_tensor(x): + x = x.numpy() + if tf.is_tensor(val): + val = val.numpy() + + if isinstance(query, (tf.Tensor, tf.Variable)): + query = query.numpy() + elif isinstance(query, tuple): + query = tuple( + q.numpy() if isinstance(q, (tf.Tensor, tf.Variable)) else q + for q in query + ) + + x[query] = val + + if isinstance(x, tf.Variable) and not copy: + x.assign(x) + return x + else: + return tf.Variable(x) if isinstance(x, tf.Variable) else tf.convert_to_tensor(x) def to_numpy(x: Union[tf.Tensor, tf.Variable], /, *, copy: bool = True) -> np.ndarray: diff --git a/ivy/functional/ivy/general.py b/ivy/functional/ivy/general.py index fceb0fc4eca7..f4f966cd4a8f 100644 --- a/ivy/functional/ivy/general.py +++ b/ivy/functional/ivy/general.py @@ -2879,25 +2879,7 @@ def set_item( ivy.array([[ 0, -1, 20], [10, 10, 10]]) """ - if copy: - x = ivy.copy_array(x) - if not ivy.is_array(val): - val = ivy.array(val) - if 0 in x.shape or 0 in val.shape: - return x - if ivy.is_array(query) and ivy.is_bool_dtype(query): - if not len(query.shape): - query = ivy.tile(query, (x.shape[0],)) - indices = ivy.nonzero(query, as_tuple=False) - else: - indices, target_shape, _ = _parse_query( - query, ivy.shape(x, as_array=True), scatter=True - ) - if indices is None: - return x - val = val.astype(x.dtype) - ret = ivy.scatter_nd(indices, val, reduction="replace", out=x) - return ret + return current_backend(x).set_item(x, query, val, copy=copy) set_item.mixed_backend_wrappers = { diff --git a/ivy/utils/decorator_utils.py b/ivy/utils/decorator_utils.py index 177ae7669a8e..222be4139c5e 100644 --- a/ivy/utils/decorator_utils.py +++ b/ivy/utils/decorator_utils.py @@ -183,8 +183,6 @@ def handle_get_item(fn): def wrapper(inp, query, **kwargs): try: res = inp.__getitem__(query) - except IndexError: - raise except Exception: res = fn(inp, query, **kwargs) return res