Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Style]: flake8 violation E721 #48084

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/ray/_private/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -1608,7 +1608,7 @@ def start_raylet(
Returns:
ProcessInfo for the process that was started.
"""
assert node_manager_port is not None and type(node_manager_port) == int
assert node_manager_port is not None and isinstance(node_manager_port, int)

if use_valgrind and use_profiler:
raise ValueError("Cannot use valgrind and profiler at the same time.")
Expand Down
2 changes: 1 addition & 1 deletion python/ray/_private/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ def wait_until_succeeded_without_exception(
Return:
Whether exception occurs within a timeout.
"""
if type(exceptions) != tuple:
if not isinstance(exceptions, tuple):
raise Exception("exceptions arguments should be given as a tuple")

time_elapsed = 0
Expand Down
2 changes: 1 addition & 1 deletion python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,7 @@ def _preprocess(self) -> None:
"the driver cannot participate in the NCCL group"
)

if type(dag_node.type_hint) == ChannelOutputType:
if isinstance(dag_node.type_hint, ChannelOutputType):
# No type hint specified by the user. Replace
# with the default type hint for this DAG.
dag_node.with_type_hint(self._default_type_hint)
Expand Down
14 changes: 7 additions & 7 deletions python/ray/dashboard/tests/test_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,7 @@ def test_immutable_types():
d["list"][0] = {str(i): i for i in range(1000)}
d["dict"] = {str(i): i for i in range(1000)}
immutable_dict = dashboard_utils.make_immutable(d)
assert type(immutable_dict) == dashboard_utils.ImmutableDict
assert isinstance(immutable_dict, dashboard_utils.ImmutableDict)
assert immutable_dict == dashboard_utils.ImmutableDict(d)
assert immutable_dict == d
assert dashboard_utils.ImmutableDict(immutable_dict) == immutable_dict
Expand All @@ -801,8 +801,8 @@ def test_immutable_types():
# Test json dumps / loads
json_str = json.dumps(immutable_dict, cls=dashboard_optional_utils.CustomEncoder)
deserialized_immutable_dict = json.loads(json_str)
assert type(deserialized_immutable_dict) == dict
assert type(deserialized_immutable_dict["list"]) == list
assert isinstance(deserialized_immutable_dict, dict)
assert isinstance(deserialized_immutable_dict["list"], list)
assert immutable_dict.mutable() == deserialized_immutable_dict
dashboard_optional_utils.rest_response(True, "OK", data=immutable_dict)
dashboard_optional_utils.rest_response(True, "OK", **immutable_dict)
Expand All @@ -815,12 +815,12 @@ def test_immutable_types():

# Test get default immutable
immutable_default_value = immutable_dict.get("not exist list", [1, 2])
assert type(immutable_default_value) == dashboard_utils.ImmutableList
assert isinstance(immutable_default_value, dashboard_utils.ImmutableList)

# Test recursive immutable
assert type(immutable_dict["list"]) == dashboard_utils.ImmutableList
assert type(immutable_dict["dict"]) == dashboard_utils.ImmutableDict
assert type(immutable_dict["list"][0]) == dashboard_utils.ImmutableDict
assert isinstance(immutable_dict["list"], dashboard_utils.ImmutableList)
assert isinstance(immutable_dict["dict"], dashboard_utils.ImmutableDict)
assert isinstance(immutable_dict["list"][0], dashboard_utils.ImmutableDict)

# Test exception
with pytest.raises(TypeError):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def _cast_large_list_to_list(batch: pyarrow.Table):

for column_name in old_schema.names:
field_type = old_schema.field(column_name).type
if type(field_type) == pyarrow.lib.LargeListType:
if isinstance(field_type, pyarrow.lib.LargeListType):
value_type = field_type.value_type

if value_type == pyarrow.large_binary():
Expand Down
5 changes: 3 additions & 2 deletions python/ray/data/tests/test_numpy_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ def do_map_batches(data):


def assert_structure_equals(a, b):
assert type(a) == type(b), (type(a), type(b))
assert type(a[0]) == type(b[0]), (type(a[0]), type(b[0])) # noqa: E721
# TODO: Remove noqa after flake8 being upgraded to 7.1.1
assert type(a) is type(b), (type(a), type(b)) # noqa E721
assert type(a[0]) is type(b[0]), (type(a[0]), type(b[0])) # noqa E721
assert a.dtype == b.dtype
assert a.shape == b.shape
for i in range(len(a)):
Expand Down
2 changes: 1 addition & 1 deletion python/ray/serve/multiplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ async def load_model(self, model_id: str) -> Any:
The user-constructed model object.
"""

if type(model_id) != str:
if not isinstance(model_id, str):
raise TypeError("The model ID must be a string.")

if not model_id:
Expand Down
2 changes: 1 addition & 1 deletion python/ray/tests/gcp/test_gcp_tpu_command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def test_max_active_connections_env_var():
cmd_runner = TPUCommandRunner(**args)
os.environ[ray_constants.RAY_TPU_MAX_CONCURRENT_CONNECTIONS_ENV_VAR] = "1"
num_connections = cmd_runner.num_connections
assert type(num_connections) == int
assert isinstance(num_connections, int)
assert num_connections == 1


Expand Down
3 changes: 2 additions & 1 deletion python/ray/tests/modin/modin_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ def df_equals(df1, df2):
if isinstance(df1, pandas.DataFrame) and isinstance(df2, pandas.DataFrame):
if (df1.empty and not df2.empty) or (df2.empty and not df1.empty):
assert False, "One of the passed frames is empty, when other isn't"
elif df1.empty and df2.empty and type(df1) != type(df2):
# TODO: Remove noqa after flake8 being upgraded to 7.1.1
elif df1.empty and df2.empty and type(df1) is not type(df2): # noqa E721
assert (
False
), f"Empty frames have different types: {type(df1)} != {type(df2)}"
Expand Down
2 changes: 1 addition & 1 deletion python/ray/tests/test_client_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_client(address):
if address in ("local", None):
assert isinstance(builder, client_builder._LocalClientBuilder)
else:
assert type(builder) == client_builder.ClientBuilder
assert isinstance(builder, client_builder.ClientBuilder)
assert builder.address == address.replace("ray://", "")


Expand Down
6 changes: 3 additions & 3 deletions python/ray/tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def is_named_tuple(cls):
f = getattr(cls, "_fields", None)
if not isinstance(f, tuple):
return False
return all(type(n) == str for n in f)
return all(isinstance(n, str) for n in f)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -95,8 +95,8 @@ def f(x):
# TODO(rkn): The numpy dtypes currently come back as regular integers
# or floats.
if type(obj).__module__ != "numpy":
assert type(obj) == type(new_obj_1)
assert type(obj) == type(new_obj_2)
assert isinstance(obj, type(new_obj_1))
assert isinstance(obj, type(new_obj_2))


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion python/ray/tune/tests/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def assertDictAlmostEqual(a, b):
assert k in b, f"Key {k} not found in {b}"
w = b[k]

assert type(v) == type(w), f"Type {type(v)} is not {type(w)}"
assert isinstance(v, type(w)), f"Type {type(v)} is not {type(w)}"

if isinstance(v, dict):
assert assertDictAlmostEqual(v, w), f"Subdict {v} != {w}"
Expand Down
4 changes: 2 additions & 2 deletions rllib/connectors/action/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def to_state(self):

@staticmethod
def from_state(ctx: ConnectorContext, params: Any):
assert (
type(params) == list
assert isinstance(
params, list
), "ActionConnectorPipeline takes a list of connector params."
connectors = []
for state in params:
Expand Down
4 changes: 2 additions & 2 deletions rllib/connectors/agent/clip_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def __init__(self, ctx: ConnectorContext, sign=False, limit=None):

def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
d = ac_data.data
assert (
type(d) == dict
assert isinstance(
d, dict
), "Single agent data must be of type Dict[str, TensorStructType]"

if SampleBatch.REWARDS not in d:
Expand Down
4 changes: 2 additions & 2 deletions rllib/connectors/agent/mean_std_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def __init__(

def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
d = ac_data.data
assert (
type(d) == dict
assert isinstance(
d, dict
), "Single agent data must be of type Dict[str, TensorStructType]"
if SampleBatch.OBS in d:
d[SampleBatch.OBS] = self.filter(
Expand Down
2 changes: 1 addition & 1 deletion rllib/connectors/agent/obs_preproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def is_identity(self):

def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
d = ac_data.data
assert type(d) == dict, (
assert isinstance(d, dict), (
"Single agent data must be of type Dict[str, TensorStructType] but is of "
"type {}".format(type(d))
)
Expand Down
4 changes: 2 additions & 2 deletions rllib/connectors/agent/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def to_state(self):

@staticmethod
def from_state(ctx: ConnectorContext, params: List[Any]):
assert (
type(params) == list
assert isinstance(
params, list
), "AgentConnectorPipeline takes a list of connector params."
connectors = []
for state in params:
Expand Down
2 changes: 1 addition & 1 deletion rllib/connectors/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_agent_connectors_from_config(
clip_rewards = __clip_rewards(config)
if clip_rewards is True:
connectors.append(ClipRewardAgentConnector(ctx, sign=True))
elif type(clip_rewards) == float:
elif isinstance(clip_rewards, float):
connectors.append(ClipRewardAgentConnector(ctx, limit=abs(clip_rewards)))

if __preprocessing_enabled(config):
Expand Down
4 changes: 2 additions & 2 deletions rllib/env/wrappers/dm_control_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def _spec_to_box(spec):
def extract_min_max(s):
assert s.dtype == np.float64 or s.dtype == np.float32
dim = np.int_(np.prod(s.shape))
if type(s) == specs.Array:
if isinstance(s, specs.Array):
bound = np.inf * np.ones(dim, dtype=np.float32)
return -bound, bound
elif type(s) == specs.BoundedArray:
elif isinstance(s, specs.BoundedArray):
zeros = np.zeros(dim, dtype=np.float32)
return s.minimum + zeros, s.maximum + zeros

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_mixin_sampling_episodes(self):
for _ in range(20):
buffer.add(batch)
sample = buffer.sample(2)
assert type(sample) == MultiAgentBatch
assert isinstance(sample, MultiAgentBatch)
results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
# One sample in the episode does not belong to the episode and thus
# gets dropped. Full episodes are of length two.
Expand All @@ -88,7 +88,7 @@ def test_mixin_sampling_sequences(self):
for _ in range(400):
buffer.add(batch)
sample = buffer.sample(10)
assert type(sample) == MultiAgentBatch
assert isinstance(sample, MultiAgentBatch)
results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
self.assertAlmostEqual(np.mean(results), 2 * len(batch), delta=0.1)

Expand All @@ -113,7 +113,7 @@ def test_mixin_sampling_timesteps(self):
buffer.add(batch)
buffer.add(batch)
sample = buffer.sample(3)
assert type(sample) == MultiAgentBatch
assert isinstance(sample, MultiAgentBatch)
results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
self.assertAlmostEqual(np.mean(results), 3.0, delta=0.2)

Expand All @@ -125,7 +125,7 @@ def test_mixin_sampling_timesteps(self):
for _ in range(100):
buffer.add(batch)
sample = buffer.sample(5)
assert type(sample) == MultiAgentBatch
assert isinstance(sample, MultiAgentBatch)
results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
self.assertAlmostEqual(np.mean(results), 1.5, delta=0.2)

Expand All @@ -142,7 +142,7 @@ def test_mixin_sampling_timesteps(self):
for _ in range(100):
buffer.add(batch)
sample = buffer.sample(10)
assert type(sample) == MultiAgentBatch
assert isinstance(sample, MultiAgentBatch)
results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
self.assertAlmostEqual(np.mean(results), 10.0, delta=0.2)

Expand All @@ -156,12 +156,12 @@ def test_mixin_sampling_timesteps(self):
buffer.add(batch)
# Expect exactly 1 batch to be returned.
sample = buffer.sample(1)
assert type(sample) == MultiAgentBatch
assert isinstance(sample, MultiAgentBatch)
self.assertTrue(len(sample) == 1)
# Expect exactly 0 sample to be returned (nothing new to be returned;
# no replay allowed (replay_ratio=0.0)).
sample = buffer.sample(1)
assert type(sample) == MultiAgentBatch
assert isinstance(sample, MultiAgentBatch)
assert len(sample.policy_batches) == 0
# If we insert and replay n times, expect roughly return batches of
# len 1 (replay_ratio=0.0 -> 0% replayed samples -> 1 new and 0 old samples
Expand All @@ -170,7 +170,7 @@ def test_mixin_sampling_timesteps(self):
for _ in range(100):
buffer.add(batch)
sample = buffer.sample(1)
assert type(sample) == MultiAgentBatch
assert isinstance(sample, MultiAgentBatch)
results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
self.assertAlmostEqual(np.mean(results), 1.0, delta=0.2)

Expand All @@ -187,19 +187,19 @@ def test_mixin_sampling_timesteps(self):
buffer.add(batch)
# Expect exactly 1 sample to be returned (the new batch).
sample = buffer.sample(1)
assert type(sample) == MultiAgentBatch
assert isinstance(sample, MultiAgentBatch)
self.assertTrue(len(sample) == 1)
# Another replay -> Expect exactly 1 sample to be returned.
sample = buffer.sample(1)
assert type(sample) == MultiAgentBatch
assert isinstance(sample, MultiAgentBatch)
self.assertTrue(len(sample) == 1)
# If we replay n times, expect roughly return batches of
# len 1 (replay_ratio=1.0 -> 100% replayed samples -> 0 new and 1 old samples
# on average in each returned value).
results = []
for _ in range(100):
sample = buffer.sample(1)
assert type(sample) == MultiAgentBatch
assert isinstance(sample, MultiAgentBatch)
results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
self.assertAlmostEqual(np.mean(results), 1.0)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def test_update_priorities(self):

# Fetch records, their indices and weights.
mabatch = buffer.sample(3)
assert type(mabatch) == MultiAgentBatch
assert isinstance(mabatch, MultiAgentBatch)
samplebatch = mabatch.policy_batches[DEFAULT_POLICY_ID]

weights = samplebatch["weights"]
Expand All @@ -211,9 +211,9 @@ def test_update_priorities(self):
# (which still has a weight of 1.0).
for _ in range(10):
mabatch = buffer.sample(1000)
assert type(mabatch) == MultiAgentBatch
assert isinstance(mabatch, MultiAgentBatch)
samplebatch = mabatch.policy_batches[DEFAULT_POLICY_ID]
assert type(mabatch) == MultiAgentBatch
assert isinstance(mabatch, MultiAgentBatch)
indices = samplebatch["batch_indexes"]
self.assertTrue(1900 < np.sum(indices) < 2200)
# Test get_state/set_state.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class LinearDiscreteEnv(gym.Env):

def __init__(self, config=None):
self.config = copy.copy(self.DEFAULT_CONFIG_LINEAR)
if config is not None and type(config) == dict:
if config is not None and isinstance(config, dict):
self.config.update(config)

self.feature_dim = self.config["feature_dim"]
Expand Down Expand Up @@ -128,7 +128,7 @@ class WheelBanditEnv(gym.Env):

def __init__(self, config=None):
self.config = copy.copy(self.DEFAULT_CONFIG_WHEEL)
if config is not None and type(config) == dict:
if config is not None and isinstance(config, dict):
self.config.update(config)

self.delta = self.config["delta"]
Expand Down
8 changes: 4 additions & 4 deletions rllib_contrib/dt/tests/test_segmentation_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ def _get_internal_buffer(
"""Get the internal buffer list from the buffer. If MultiAgent then return the
internal buffer corresponding to the given policy_id.
"""
if type(buffer) == SegmentationBuffer:
if isinstance(buffer, SegmentationBuffer):
return buffer._buffer
elif type(buffer) == MultiAgentSegmentationBuffer:
elif isinstance(buffer, MultiAgentSegmentationBuffer):
return buffer.buffers[policy_id]._buffer
else:
raise NotImplementedError
Expand All @@ -104,9 +104,9 @@ def _as_sample_batch(
"""Returns a SampleBatch. If MultiAgentBatch then return the SampleBatch
corresponding to the given policy_id.
"""
if type(batch) == SampleBatch:
if isinstance(batch, SampleBatch):
return batch
elif type(batch) == MultiAgentBatch:
elif isinstance(batch, MultiAgentBatch):
return batch.policy_batches[policy_id]
else:
raise NotImplementedError
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def forward(self, input_dict, state, seq_lens):
print(input_dict)
raise Exception("No observation in input_dict")
if self.alpha_zero_obs:
if not type(obs) == torch.Tensor:
if not isinstance(obs, torch.Tensor):
obs = torch.from_numpy(obs.astype(np.float32))
action_mask = torch.from_numpy(action_mask.astype(np.float32))
try:
Expand Down
Loading