Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Google Colab Error: optax is throwing an attribute error. #230

Open
prajjwalgeek opened this issue Jun 30, 2022 · 2 comments
Open

Google Colab Error: optax is throwing an attribute error. #230

prajjwalgeek opened this issue Jun 30, 2022 · 2 comments

Comments

@prajjwalgeek
Copy link

prajjwalgeek commented Jun 30, 2022

Importing optax raises an `AttributeError` when running the attached Google Colab inference demo:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
[<ipython-input-14-a22d9a83aa66>](https://localhost:8080/#) in <module>()
      4 from jax.experimental import maps
      5 import numpy as np
----> 6 import optax
      7 import transformers
      8 

6 frames
[/usr/local/lib/python3.7/dist-packages/optax/__init__.py](https://localhost:8080/#) in <module>()
     15 """Optax: composable gradient processing and optimization, in JAX."""
     16 
---> 17 from optax._src.alias import adabelief
     18 from optax._src.alias import adafactor
     19 from optax._src.alias import adagrad

[/usr/local/lib/python3.7/dist-packages/optax/_src/alias.py](https://localhost:8080/#) in <module>()
     19 import jax.numpy as jnp
     20 
---> 21 from optax._src import base
     22 from optax._src import clipping
     23 from optax._src import combine

[/usr/local/lib/python3.7/dist-packages/optax/_src/base.py](https://localhost:8080/#) in <module>()
     16 
     17 from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple, Union
---> 18 import chex
     19 
     20 # pylint:disable=no-value-for-parameter

[/usr/local/lib/python3.7/dist-packages/chex/__init__.py](https://localhost:8080/#) in <module>()
     15 """Chex: Testing made fun, in JAX!"""
     16 
---> 17 from chex._src.asserts import assert_axis_dimension
     18 from chex._src.asserts import assert_axis_dimension_comparator
     19 from chex._src.asserts import assert_axis_dimension_gt

[/usr/local/lib/python3.7/dist-packages/chex/_src/asserts.py](https://localhost:8080/#) in <module>()
     24 from unittest import mock
     25 
---> 26 from chex._src import asserts_internal as _ai
     27 from chex._src import pytypes
     28 import jax

[/usr/local/lib/python3.7/dist-packages/chex/_src/asserts_internal.py](https://localhost:8080/#) in <module>()
     30 
     31 from absl import logging
---> 32 from chex._src import pytypes
     33 import jax
     34 import jax.numpy as jnp

[/usr/local/lib/python3.7/dist-packages/chex/_src/pytypes.py](https://localhost:8080/#) in <module>()
     34 Scalar = Union[float, int]
     35 Numeric = Union[Array, Scalar]
---> 36 PRNGKey = jax.random.KeyArray
     37 PyTreeDef = type(jax.tree_structure(None))
     38 Shape = jax.core.Shape

AttributeError: module 'jax.random' has no attribute 'KeyArray'
@prajjwalgeek prajjwalgeek changed the title optax is throwing an attribute error. Google Colab Error: optax is throwing an attribute error. Jun 30, 2022
@neverix
Copy link

neverix commented Jul 5, 2022

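I think I finally figured it out. The installed chex (imported by optax) references `jax.random.KeyArray`, which the installed jax does not provide, so `import optax` fails. The workaround is to pin specific package versions and patch chex so it no longer needs that attribute: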

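  1. Pin specific package versions: `!pip install mesh-transformer-jax/ jax==0.2.12 tensorflow==2.5.0 chex==0.0.6 jaxlib==0.3.7`
  2. Run the two cells below, which overwrite chex's installed `pytypes.py` and `__init__.py` via `%%file`: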

#@title Patch 1
%%file /usr/local/lib/python3.7/dist-packages/chex/_src/pytypes.py
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pytypes for arrays and scalars.

Patched copy of chex's pytypes module: ``PRNGKey`` is a plain array alias
(no ``jax.random.KeyArray``) and no ``CpuDevice`` alias is defined.
"""

from typing import Any, Iterable, Mapping, Tuple, Union
import jax
import jax.numpy as jnp
import numpy as np

Array = jnp.ndarray
ArrayBatched = jax.interpreters.batching.BatchTracer
ArrayNumpy = np.ndarray
ArraySharded = jax.interpreters.pxla.ShardedDeviceArray
# Use this type for type annotation. For instance checking, use
# `isinstance(x, jax.DeviceArray)`.
# `jax.interpreters.xla._DeviceArray` appears in jax > 0.2.5
if hasattr(jax.interpreters.xla, '_DeviceArray'):
  ArrayDevice = jax.interpreters.xla._DeviceArray  # pylint:disable=protected-access
else:
  ArrayDevice = jax.interpreters.xla.DeviceArray

Scalar = Union[float, int]
Numeric = Union[Array, Scalar]
# Patched: PRNGKey is just an array alias. The chex release that failed to
# import defined it as `jax.random.KeyArray`, which the jax in that
# environment did not have (AttributeError on import).
PRNGKey = Array
Shape = Tuple[int, ...]

# NOTE(review): CpuDevice is disabled in this patch, presumably because the
# pinned jaxlib's `xla_extension` has no `CpuDevice` attribute -- confirm.
# Consequently `Device` below only covers GPU and TPU device types.
# CpuDevice = jax.lib.xla_extension.CpuDevice
GpuDevice = jax.lib.xla_extension.GpuDevice
TpuDevice = jax.lib.xla_extension.TpuDevice
Device = Union[GpuDevice, TpuDevice]

# As of 06/2020 pytype doesn't support recursive types (see b/109648354)
# pytype: disable=not-supported-yet
ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]
#@title Patch 2
%%file /usr/local/lib/python3.7/dist-packages/chex/__init__.py
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Chex: Testing made fun, in JAX!

Patched package entry point: ``CpuDevice`` is neither imported nor exported,
matching a patched ``chex._src.pytypes`` that does not define it.
"""

from chex._src.asserts import assert_axis_dimension
from chex._src.asserts import assert_axis_dimension_gt
from chex._src.asserts import assert_devices_available
from chex._src.asserts import assert_equal
from chex._src.asserts import assert_equal_rank
from chex._src.asserts import assert_equal_shape
from chex._src.asserts import assert_equal_shape_prefix
from chex._src.asserts import assert_equal_shape_suffix
from chex._src.asserts import assert_exactly_one_is_none
from chex._src.asserts import assert_gpu_available
from chex._src.asserts import assert_is_broadcastable
from chex._src.asserts import assert_max_traces
from chex._src.asserts import assert_not_both_none
from chex._src.asserts import assert_numerical_grads
from chex._src.asserts import assert_rank
from chex._src.asserts import assert_scalar
from chex._src.asserts import assert_scalar_in
from chex._src.asserts import assert_scalar_negative
from chex._src.asserts import assert_scalar_non_negative
from chex._src.asserts import assert_scalar_positive
from chex._src.asserts import assert_shape
from chex._src.asserts import assert_tpu_available
from chex._src.asserts import assert_tree_all_close
from chex._src.asserts import assert_tree_all_equal_comparator
from chex._src.asserts import assert_tree_all_equal_shapes
from chex._src.asserts import assert_tree_all_equal_structs
from chex._src.asserts import assert_tree_all_finite
from chex._src.asserts import assert_tree_no_nones
from chex._src.asserts import assert_tree_shape_prefix
from chex._src.asserts import assert_type
from chex._src.asserts import clear_trace_counter
from chex._src.asserts import if_args_not_none
from chex._src.dataclass import dataclass
from chex._src.dataclass import mappable_dataclass
from chex._src.fake import fake_jit
from chex._src.fake import fake_pmap
from chex._src.fake import fake_pmap_and_jit
from chex._src.fake import set_n_cpu_devices
from chex._src.pytypes import Array
from chex._src.pytypes import ArrayBatched
from chex._src.pytypes import ArrayDevice
from chex._src.pytypes import ArrayNumpy
from chex._src.pytypes import ArraySharded
from chex._src.pytypes import ArrayTree
# CpuDevice is intentionally not imported: the patched pytypes module does not
# define it.
from chex._src.pytypes import Device
from chex._src.pytypes import GpuDevice
from chex._src.pytypes import Numeric
from chex._src.pytypes import PRNGKey
from chex._src.pytypes import Scalar
from chex._src.pytypes import Shape
from chex._src.pytypes import TpuDevice
from chex._src.variants import all_variants
from chex._src.variants import ChexVariantType
from chex._src.variants import params_product
from chex._src.variants import TestCase
from chex._src.variants import variants


__version__ = "0.0.6"

# Every name listed here must be bound in this module; otherwise
# `from chex import *` raises AttributeError. "CpuDevice" is therefore omitted.
__all__ = (
    "all_variants",
    "Array",
    "ArrayBatched",
    "ArrayDevice",
    "ArrayNumpy",
    "ArraySharded",
    "ArrayTree",
    "assert_axis_dimension",
    "assert_axis_dimension_gt",
    "assert_devices_available",
    "assert_equal",
    "assert_equal_rank",
    "assert_equal_shape",
    "assert_equal_shape_prefix",
    "assert_equal_shape_suffix",
    "assert_exactly_one_is_none",
    "assert_gpu_available",
    "assert_is_broadcastable",
    "assert_max_traces",
    "assert_not_both_none",
    "assert_numerical_grads",
    "assert_rank",
    "assert_scalar",
    "assert_scalar_in",
    "assert_scalar_negative",
    "assert_scalar_non_negative",
    "assert_scalar_positive",
    "assert_shape",
    "assert_tpu_available",
    "assert_tree_all_close",
    "assert_tree_all_equal_comparator",
    "assert_tree_all_equal_shapes",
    "assert_tree_all_equal_structs",
    "assert_tree_all_finite",
    "assert_tree_no_nones",
    "assert_tree_shape_prefix",
    "assert_type",
    "ChexVariantType",
    "clear_trace_counter",
    "dataclass",
    "Device",
    "fake_jit",
    "fake_pmap",
    "fake_pmap_and_jit",
    "GpuDevice",
    "if_args_not_none",
    "mappable_dataclass",
    "Numeric",
    "params_product",
    "PRNGKey",
    "Scalar",
    "set_n_cpu_devices",
    "Shape",
    "TestCase",
    "TpuDevice",
    "variants",
)

With these few patches, it seems to work on Colab's TPU.

@vfbd
Copy link
Contributor

vfbd commented Jul 8, 2022

#221

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants