From 67d382bf433d94fb5a68b6a51787d2e132e76597 Mon Sep 17 00:00:00 2001 From: Nicolas Dickreuter Date: Sun, 25 Feb 2024 01:34:34 +0000 Subject: [PATCH] fix for correctly moving into next street --- gym_env/cycle.py | 14 +++++++------- gym_env/env.py | 4 ++-- tests/test_gym_env.py | 34 +++++++++++++++++++--------------- 3 files changed, 28 insertions(+), 24 deletions(-) diff --git a/gym_env/cycle.py b/gym_env/cycle.py index 8013b64..f4e0446 100644 --- a/gym_env/cycle.py +++ b/gym_env/cycle.py @@ -44,14 +44,15 @@ def new_hand_reset(self): self.folder = [False] * len(self.lst) self.step_counter = 0 - def new_round_reset(self): + def new_street_reset(self): """Reset the state for the next stage: flop, turn or river""" self.step_counter = 0 self.round_number_in_street = 0 self.idx = self.dealer_idx self.last_raiser_step = len(self.lst) self.checkers = 0 - self.max_remaining_steps_without_raising = len(self.alive) + self.max_remaining_steps_without_raising = len(self.alive) - 1 + self.last_raiser = None def next_player(self, step=1): """Switch to the next player in the round.""" @@ -64,12 +65,12 @@ def next_player(self, step=1): self.idx %= len(self.lst) if self.step_counter > len(self.lst): self.round_number_in_street += 1 - if self.max_steps_total and (self.step_counter >= self.max_steps_total): + if self.max_steps_total and (self.step_counter > self.max_steps_total): log.info("Max steps total has been reached") return False if self.last_raiser: - if self.step_counter >= self.last_raiser + self.max_remaining_steps_without_raising: + if self.step_counter > self.last_raiser + self.max_remaining_steps_without_raising: log.info("Max steps without raising has been reached. For example all calls after raiser.") return False @@ -133,8 +134,7 @@ def mark_folder(self): def mark_raiser(self): """Mark a raise for the current player.""" - if self.step_counter > 2: - self.last_raiser = self.step_counter + self.last_raiser = self.step_counter def mark_checker(self): """Counter the number of checks in the round""" @@ -148,7 +148,7 @@ def mark_out_of_cash_but_contributed(self): def mark_bb(self): """Ensure bb can raise""" self.last_raiser_step = self.step_counter + len(self.lst) - self.max_steps_total = self.step_counter + len(self.lst) * self.max_raises_per_player_round + self.max_steps_total = self.step_counter + len(self.lst) * self.max_raises_per_player_round + 2 def is_raising_allowed(self): """Check if raising is still allowed at this position""" diff --git a/gym_env/env.py b/gym_env/env.py index 3aab590..f850ac8 100644 --- a/gym_env/env.py +++ b/gym_env/env.py @@ -385,7 +385,7 @@ def _process_decision(self, action): # pylint: disable=too-many-statements else: raise RuntimeError("Illegal action.") - if contribution > self.min_call: + if contribution > self.min_call and not (action==Action.BIG_BLIND or action==Action.SMALL_BLIND): self.player_cycle.mark_raiser() self.current_player.stack -= contribution @@ -504,7 +504,7 @@ def _initiate_round(self): self.min_call = 0 for player in self.players: player.last_action_in_stage = '' - self.player_cycle.new_round_reset() + self.player_cycle.new_street_reset() if self.stage == Stage.PREFLOP: log.info("") diff --git a/tests/test_gym_env.py b/tests/test_gym_env.py index fef4697..49e5309 100644 --- a/tests/test_gym_env.py +++ b/tests/test_gym_env.py @@ -205,18 +205,6 @@ def test_cycle_mechanism1(): assert current == 'utg1' -def test_cycle_mechanism2(): - """Test cycle""" - lst = ['dealer', 'sb', 'bb', 'utg'] - cycle = PlayerCycle(lst, start_idx=2, max_steps_total=5) - current = cycle.next_player() - assert current == 'utg' - cycle.next_player() - cycle.next_player() - current = cycle.next_player(step=2) - assert not current - - class PlayerForTest: """Player shell""" @@ -274,6 +262,8 @@ def test_unlimited_raising_preflop(): env.step(Action.RAISE_POT) # bb raises assert env.stage == Stage.PREFLOP env.step(Action.RAISE_POT) # sb calls + assert env.stage == Stage.PREFLOP + env.step(Action.CALL) # sb calls assert env.stage == Stage.FLOP @@ -289,21 +279,35 @@ def test_end_preflop_on_call(): assert env.stage == Stage.FLOP -@pytest.mark.skip("Requires further discussion") def test_preflop_call_after_max_raises(): """Test that the preflop round ends when there is a call after a raise """ env = _create_env(2, initial_stacks=100000, max_raises_per_player_round=2) + # sb + # bb env.step(Action.CALL) # sb env.step(Action.RAISE_POT) # bb raises env.step(Action.RAISE_POT) # sb raises + assert env.stage == Stage.PREFLOP env.step(Action.RAISE_POT) # bb raises + assert env.stage == Stage.PREFLOP env.step(Action.RAISE_POT) # sb raises - + assert env.stage == Stage.PREFLOP # Now we should still be in preflop, but raises are no longer legal actions # Only a Call or Fold would end the round - assert env.stage == Stage.PREFLOP + assert env.legal_moves == [Action.CALL, Action.FOLD] + + env.step(Action.CALL) + assert env.stage == Stage.FLOP + env.step(Action.RAISE_POT) + env.step(Action.CALL) + assert env.stage == Stage.TURN + env.step(Action.RAISE_POT) + env.step(Action.CALL) + assert env.stage == Stage.RIVER + env.step(Action.RAISE_POT) + env.step(Action.CALL) def test_one_max_raise_per_player():