diff --git a/.github/actions/install-concordia/action.yml b/.github/actions/install-concordia/action.yml new file mode 100644 index 00000000..5c6fef6a --- /dev/null +++ b/.github/actions/install-concordia/action.yml @@ -0,0 +1,69 @@ +name: install-concordia + +inputs: + python-version: + description: Python version + required: false + default: '3.11' + type: string + +runs: + using: composite + steps: + - name: Get current runner + id: os-info + shell: bash + run: | + if [ "${RUNNER_OS}" = 'macOS' ]; then + echo "name=$(sw_vers -productName)" >> $GITHUB_OUTPUT + echo "version=$(sw_vers -productVersion)" >> $GITHUB_OUTPUT + elif [ "${RUNNER_OS}" = 'Linux' ]; then + echo "name=$(lsb_release -i -s)" >> $GITHUB_OUTPUT + echo "version=$(lsb_release -r -s)" >> $GITHUB_OUTPUT + else + exit 1 + fi + + - name: Set up Python ${{ inputs.python-version }} + uses: actions/setup-python@61a6322f88396a6271a6ee3565807d608ecaddd1 + with: + python-version: ${{ inputs.python-version }} + cache: 'pip' + cache-dependency-path: setup.py + + - name: Restore Concordia installation + id: restore + uses: actions/cache/restore@v3 + with: + path: | + concordia/assets + venv + key: install-concordia-${{ steps.os-info.outputs.name }}-${{ steps.os-info.outputs.version }}-py${{ inputs.python-version}}-${{ hashFiles('setup.py') }} + restore-keys: | + install-concordia-${{ steps.os-info.outputs.name }}-${{ steps.os-info.outputs.version }}-py${{ inputs.python-version }}- + + - name: Install Concordia + if: steps.restore.outputs.cache-hit != 'true' + shell: bash + run: | + pip install --upgrade pip + pip install virtualenv + virtualenv venv + source venv/bin/activate + pip install --editable .[dev] + + - name: Save Concordia installation + if: steps.restore.outputs.cache-hit != 'true' + uses: actions/cache/save@v3 + with: + path: | + concordia/assets + venv + key: ${{ steps.restore.outputs.cache-primary-key }} + + - name: Activate virtual environment + shell: bash + run: | + source venv/bin/activate + 
pip list + echo "PATH=${PATH}" >> $GITHUB_ENV diff --git a/.github/actions/install-examples/action.yml b/.github/actions/install-examples/action.yml new file mode 100644 index 00000000..17b43a0b --- /dev/null +++ b/.github/actions/install-examples/action.yml @@ -0,0 +1,15 @@ +name: install-examples + +runs: + using: composite + steps: + - name: Install Concordia + uses: ./.github/actions/install-concordia + + - name: Install requirements for examples + shell: bash + run: pip install -r examples/requirements.txt + + - name: Show installed dependencies + shell: bash + run: pip list diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..e9335622 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,19 @@ +version: 2 +updates: + - package-ecosystem: pip + directory: / + schedule: + interval: monthly + + - package-ecosystem: github-actions + directory: / + schedule: + interval: monthly + + - package-ecosystem: docker + directory: /.devcontainer + schedule: + interval: monthly + ignore: + - dependency-name: "vscode/devcontainers/python" + versions: [">= 3.11"] diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml new file mode 100644 index 00000000..c98a953e --- /dev/null +++ b/.github/workflows/codeql-analysis.yml @@ -0,0 +1,62 @@ +name: "CodeQL" + +on: + push: + branches: [ "main" ] + pull_request: + # The branches below must be a subset of the branches above + branches: [ "main" ] + schedule: + - cron: '36 13 * * 4' + +# Declare default permissions as read only. +permissions: read-all + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ 'python' ] + + steps: + - name: Checkout repository + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 + + # Initializes the CodeQL tools for scanning. 
+ - name: Initialize CodeQL + uses: github/codeql-action/init@74483a38d39275f33fcff5f35b679b5ca4a26a99 + with: + languages: ${{ matrix.language }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + + # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs + # queries: security-extended,security-and-quality + + + # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). + # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@74483a38d39275f33fcff5f35b679b5ca4a26a99 + + # ℹī¸ Command-line programs to run using the OS shell. + # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun + + # If the Autobuild fails above, remove it and uncomment the following three lines. + # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 
+ + # - run: | + # echo "Run, Build Application using script" + # ./location_of_script_within_repo/buildscript.sh + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@74483a38d39275f33fcff5f35b679b5ca4a26a99 diff --git a/.github/workflows/pylint-concordia.yml b/.github/workflows/pylint-concordia.yml new file mode 100644 index 00000000..ae6e2f2d --- /dev/null +++ b/.github/workflows/pylint-concordia.yml @@ -0,0 +1,45 @@ +name: pylint-concordia + +on: + push: + branches: + - main + paths: + - '.github/actions/install-concordia/action.yml' + - '.github/workflows/pylint-concordia.yml' + - '.pylintrc' + - 'examples/**' + - 'concordia/**' + - 'setup.py' + pull_request: + branches: + - main + paths: + - '.github/actions/install-concordia/action.yml' + - '.github/workflows/pylint-concordia.yml' + - '.pylintrc' + - 'examples/**' + - 'concordia/**' + - 'setup.py' + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +permissions: read-all + +jobs: + pylint: + name: Lint Concordia + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - name: Checkout Concordia + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 + + - name: Install Concordia + uses: ./.github/actions/install-concordia + + - name: Run PyLint on Concordia + run: pylint --errors-only concordia diff --git a/.github/workflows/pylint-examples.yml b/.github/workflows/pylint-examples.yml new file mode 100644 index 00000000..5426355d --- /dev/null +++ b/.github/workflows/pylint-examples.yml @@ -0,0 +1,47 @@ +name: pylint-examples + +on: + push: + branches: + - main + paths: + - '.github/actions/install-examples/action.yml' + - '.github/actions/install-concordia/action.yml' + - '.github/workflows/pylint-examples.yml' + - '.pylintrc' + - 'examples/**' + - 'concordia/**' + - 'setup.py' + pull_request: + branches: + - main + paths: + - '.github/actions/install-examples/action.yml' + - 
'.github/actions/install-concordia/action.yml' + - '.github/workflows/pylint-examples.yml' + - '.pylintrc' + - 'examples/**' + - 'concordia/**' + - 'setup.py' + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +permissions: read-all + +jobs: + pylint: + name: Lint examples + runs-on: ubuntu-latest + timeout-minutes: 5 + steps: + - name: Checkout Concordia + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 + + - name: Install examples + uses: ./.github/actions/install-examples + + - name: Run PyLint on examples + run: pylint --errors-only examples diff --git a/.github/workflows/pypi-publish.yml b/.github/workflows/pypi-publish.yml new file mode 100644 index 00000000..484d9dee --- /dev/null +++ b/.github/workflows/pypi-publish.yml @@ -0,0 +1,76 @@ +# A workflow to publish releases to PyPi and TestPyPi. + +name: pypi-publish + +on: + release: + types: [published] + workflow_dispatch: + inputs: + test_sdist: + description: 'Test the sdist before uploading' + type: boolean + default: true + upload_to_test_pypi: + description: 'Upload to Test PyPi' + type: boolean + default: true + upload_to_pypi: + description: 'Upload to PyPi' + type: boolean + default: false + +permissions: read-all + +jobs: + pypi-publish: + name: Upload to PyPI + runs-on: ubuntu-latest + environment: + name: pypi + url: https://pypi.org/p/dm-concordia + permissions: + id-token: write + timeout-minutes: 90 + + steps: + - name: Checkout Concordia + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 + + - name: Set up Python + uses: actions/setup-python@65d7f2d534ac1bc67fcd62888c5f4f3d2cb2b236 + with: + python-version: '3.11' + + - name: Install Python dependencies + run: | + pip install --upgrade pip + pip install build + + - name: Build source distribution + run: python -m build --sdist --outdir dist/ + + - name: Install from source distribution + if: github.event_name == 'release' || 
inputs.test_sdist + run: | + pip install setuptools + pip -vvv install dist/*.tar.gz + + - name: Test installation + if: github.event_name == 'release' || inputs.test_sdist + run: | + pip install pytest-xdist + pytest -n auto -rax --pyargs concordia + + - name: Publish to TestPyPI + if: github.event_name == 'release' || inputs.upload_to_test_pypi + uses: pypa/gh-action-pypi-publish@b7f401de30cb6434a1e19f805ff006643653240e + with: + repository-url: https://test.pypi.org/legacy/ + verbose: true + + - name: Publish to PyPI + if: github.event_name == 'release' || inputs.upload_to_pypi + uses: pypa/gh-action-pypi-publish@b7f401de30cb6434a1e19f805ff006643653240e + with: + verbose: true diff --git a/.github/workflows/pypi-test.yml b/.github/workflows/pypi-test.yml new file mode 100644 index 00000000..6c2d01aa --- /dev/null +++ b/.github/workflows/pypi-test.yml @@ -0,0 +1,63 @@ +# Continuous integration tests. + +name: pypi-test + +on: + schedule: + - cron: "0 2 * * 1" # Every Monday at 2am. + push: + branches: + - main + paths: + - '.github/workflows/pypi-test.yml' + pull_request: + branches: + - main + paths: + - '.github/workflows/pypi-test.yml' + workflow_run: + workflows: + - pypi-publish + types: + - completed + workflow_dispatch: + +permissions: read-all + +jobs: + pypi-test: + name: Test PyPI Distribution + if: ${{ github.event.workflow_run.conclusion != 'failure' }} + runs-on: ${{ matrix.os }} + env: + SYSTEM_VERSION_COMPAT: 0 # See https://github.com/actions/setup-python/issues/279. 
+ timeout-minutes: 120 + strategy: + fail-fast: false + matrix: + os: + - macos-11 + - macos-12 + - ubuntu-20.04 + - ubuntu-22.04 + python-version: + - '3.11' + + steps: + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@65d7f2d534ac1bc67fcd62888c5f4f3d2cb2b236 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Python dependencies + run: | + pip install --upgrade pip + pip install pytest-xdist setuptools + + - name: Install from PyPI + run: | + pip -vvv install dm-concordia + pip list + + - name: Test installation + run: pytest --pyargs concordia diff --git a/.github/workflows/pytype-concordia.yml b/.github/workflows/pytype-concordia.yml new file mode 100644 index 00000000..eb05e0d5 --- /dev/null +++ b/.github/workflows/pytype-concordia.yml @@ -0,0 +1,45 @@ +name: pytype-concordia + +on: + push: + branches: + - main + paths: + - '.github/actions/install-concordia/action.yml' + - '.github/workflows/pytype-concordia.yml' + - 'examples/**' + - 'concordia/**' + - 'pyproject.toml' + - 'setup.py' + pull_request: + branches: + - main + paths: + - '.github/actions/install-concordia/action.yml' + - '.github/workflows/pytype-concordia.yml' + - 'examples/**' + - 'concordia/**' + - 'pyproject.toml' + - 'setup.py' + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +permissions: read-all + +jobs: + pytype: + name: Typecheck Concordia + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - name: Checkout Concordia + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 + + - name: Install Concordia + uses: ./.github/actions/install-concordia + + - name: Run PyType on Concordia + run: pytype concordia diff --git a/.github/workflows/pytype-examples.yml b/.github/workflows/pytype-examples.yml new file mode 100644 index 00000000..c744e5c3 --- /dev/null +++ b/.github/workflows/pytype-examples.yml @@ -0,0 +1,47 @@ +name:
pytype-examples + +on: + push: + branches: + - main + paths: + - '.github/actions/install-examples/action.yml' + - '.github/actions/install-concordia/action.yml' + - '.github/workflows/pytype-examples.yml' + - 'examples/**' + - 'concordia/**' + - 'pyproject.toml' + - 'setup.py' + pull_request: + branches: + - main + paths: + - '.github/actions/install-examples/action.yml' + - '.github/actions/install-concordia/action.yml' + - '.github/workflows/pytype-examples.yml' + - 'examples/**' + - 'concordia/**' + - 'pyproject.toml' + - 'setup.py' + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +permissions: read-all + +jobs: + pytype: + name: Typecheck examples + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - name: Checkout Concordia + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 + + - name: Install examples + uses: ./.github/actions/install-examples + + - name: Run PyType on examples + run: pytype examples diff --git a/.github/workflows/scorecards-analysis.yml b/.github/workflows/scorecards-analysis.yml new file mode 100644 index 00000000..da99853a --- /dev/null +++ b/.github/workflows/scorecards-analysis.yml @@ -0,0 +1,62 @@ +name: Scorecards supply-chain security +on: + # Only the default branch is supported. + branch_protection_rule: + schedule: + - cron: '17 10 * * 0' + push: + branches: [ "main" ] + +# Declare default permissions as read only. +permissions: read-all + +jobs: + analysis: + name: Scorecards analysis + runs-on: ubuntu-latest + permissions: + # Needed to upload the results to code-scanning dashboard. + security-events: write + # Used to receive a badge. (Upcoming feature) + id-token: write + # Needs for private repositories. 
+ contents: read + actions: read + + steps: + - name: "Checkout code" + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 + with: + persist-credentials: false + + - name: "Run analysis" + uses: ossf/scorecard-action@0864cf19026789058feabb7e87baa5f140aac736 + with: + results_file: results.sarif + results_format: sarif + # (Optional) Read-only PAT token. Uncomment the `repo_token` line below if: + # - you want to enable the Branch-Protection check on a *public* repository, or + # - you are installing Scorecards on a *private* repository + # To create the PAT, follow the steps in https://github.com/ossf/scorecard-action#authentication-with-pat. + # repo_token: ${{ secrets.SCORECARD_READ_TOKEN }} + + # Publish the results for public repositories to enable scorecard badges. For more details, see + # https://github.com/ossf/scorecard-action#publishing-results. + # For private repositories, `publish_results` will automatically be set to `false`, regardless + # of the value entered here. + publish_results: true + + # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF + # format to the repository Actions tab. + - name: "Upload artifact" + uses: actions/upload-artifact@a8a3f3ad30e3422c9c7b888a15615d19a852ae32 + with: + name: SARIF file + path: results.sarif + retention-days: 5 + + # Upload the results to GitHub's code scanning dashboard. 
+ - name: "Upload to code-scanning" + uses: github/codeql-action/upload-sarif@74483a38d39275f33fcff5f35b679b5ca4a26a99 + with: + sarif_file: results.sarif diff --git a/.github/workflows/test-concordia.yml b/.github/workflows/test-concordia.yml new file mode 100644 index 00000000..cba8c84e --- /dev/null +++ b/.github/workflows/test-concordia.yml @@ -0,0 +1,53 @@ +name: test-concordia + +on: + push: + branches: + - main + paths: + - '.github/actions/install-concordia/action.yml' + - '.github/workflows/test-concordia.yml' + - 'concordia/**' + - 'pyproject.toml' + - 'setup.py' + pull_request: + branches: + - main + paths: + - '.github/actions/install-concordia/action.yml' + - '.github/workflows/test-concordia.yml' + - 'concordia/**' + - 'pyproject.toml' + - 'setup.py' + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +permissions: read-all + +jobs: + pytest: + name: Test Concordia + runs-on: ${{ matrix.os }} + env: + SYSTEM_VERSION_COMPAT: 0 # See https://github.com/actions/setup-python/issues/279. 
+ timeout-minutes: 120 + strategy: + fail-fast: ${{ github.event_name != 'workflow_dispatch' }} + matrix: + os: + - macos-11 + - ubuntu-20.04 + python-version: + - '3.11' + steps: + - name: Checkout Concordia + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 + - name: Install Concordia + uses: ./.github/actions/install-concordia + with: + python-version: ${{ matrix.python-version }} + - name: Test Concordia + run: pytest concordia diff --git a/.github/workflows/test-examples.yml b/.github/workflows/test-examples.yml new file mode 100644 index 00000000..8eeb6383 --- /dev/null +++ b/.github/workflows/test-examples.yml @@ -0,0 +1,47 @@ +name: test-examples + +on: + push: + branches: + - main + paths: + - '.github/actions/install-examples/action.yml' + - '.github/actions/install-concordia/action.yml' + - '.github/workflows/test-examples.yml' + - 'examples/**' + - 'concordia/**' + - 'pyproject.toml' + - 'setup.py' + pull_request: + branches: + - main + paths: + - '.github/actions/install-examples/action.yml' + - '.github/actions/install-concordia/action.yml' + - '.github/workflows/test-examples.yml' + - 'examples/**' + - 'concordia/**' + - 'pyproject.toml' + - 'setup.py' + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +permissions: read-all + +jobs: + pytest: + name: Test examples + runs-on: ubuntu-latest + timeout-minutes: 90 + steps: + - name: Checkout Concordia + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 + + - name: Install examples + uses: ./.github/actions/install-examples + + - name: Test examples + run: pytest examples diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..807e6de3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,25 @@ +# MacOS metadata. +.DS_Store + +# Byte-compiled Python code. +*.py[cod] +__pycache__/ +.cache + +# Common venv names. +/*venv/ +/venv*/ + +# Ignore files created during installation and building.
+*.egg-info +.eggs/ +/assets/ +/build/ +/dist/ +/lab2d/ + +# Test cache +.pytest_cache + +# Type checking +.pytype diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 00000000..5c771e62 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,399 @@ +# This Pylint rcfile contains a best-effort configuration to uphold the +# best-practices and style described in the Google Python style guide: +# https://google.github.io/styleguide/pyguide.html +# +# Its canonical open-source location is: +# https://google.github.io/styleguide/pylintrc + +[MAIN] + +# Files or directories to be skipped. They should be base names, not paths. +ignore=third_party + +# Files or directories matching the regex patterns are skipped. The regex +# matches against base names, not paths. +ignore-patterns= + +# Pickle collected data for later comparisons. +persistent=no + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Use multiple processes to speed up Pylint. +jobs=4 + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED +confidence= + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +#enable= + +# Disable the message, report, category or checker with the given id(s). 
You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" +disable=R, + abstract-method, + apply-builtin, + arguments-differ, + attribute-defined-outside-init, + backtick, + bad-option-value, + basestring-builtin, + buffer-builtin, + c-extension-no-member, + consider-using-enumerate, + cmp-builtin, + cmp-method, + coerce-builtin, + coerce-method, + delslice-method, + div-method, + eq-without-hash, + execfile-builtin, + file-builtin, + filter-builtin-not-iterating, + fixme, + getslice-method, + global-statement, + hex-method, + idiv-method, + implicit-str-concat, + import-error, + import-self, + import-star-module-level, + input-builtin, + intern-builtin, + invalid-str-codec, + locally-disabled, + long-builtin, + long-suffix, + map-builtin-not-iterating, + misplaced-comparison-constant, + missing-function-docstring, + metaclass-assignment, + next-method-called, + next-method-defined, + no-absolute-import, + no-init, # added + no-member, + no-name-in-module, + no-self-use, + nonzero-method, + oct-method, + old-division, + old-ne-operator, + old-octal-literal, + old-raise-syntax, + parameter-unpacking, + print-statement, + raising-string, + range-builtin-not-iterating, + raw_input-builtin, + rdiv-method, + reduce-builtin, + relative-import, + reload-builtin, + round-builtin, + setslice-method, + signature-differs, + standarderror-builtin, + suppressed-message, + sys-max-int, + trailing-newlines, + unichr-builtin, + unicode-builtin, + unnecessary-pass, + unpacking-in-except, 
+ useless-else-on-loop, + useless-suppression, + using-cmp-argument, + wrong-import-order, + xrange-builtin, + zip-builtin-not-iterating, + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages +reports=no + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + + +[BASIC] + +# Good variable names which should always be accepted, separated by a comma +good-names=main,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Include a hint for the correct naming format with invalid-name +include-naming-hint=no + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. 
+property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl + +# Regular expression matching correct function names +function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ + +# Regular expression matching correct variable names +variable-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct constant names +const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct attribute names +attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ + +# Regular expression matching correct argument names +argument-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class attribute names +class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct inline iteration names +inlinevar-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class names +class-rgx=^_?[A-Z][a-zA-Z0-9]*$ + +# Regular expression matching correct module names +module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ + +# Regular expression matching correct method names +method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=12 + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. 
+contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis. It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=80 + +# TODO(https://github.com/pylint-dev/pylint/issues/3352): Direct pylint to exempt +# lines made too long by directives to pytype. + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=(?x)( + ^\s*(\#\ )??$| + ^\s*(from\s+\S+\s+)?import\s+.+$) + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=yes + +# Maximum number of lines in a module +max-module-lines=99999 + +# String used as indentation unit. The internal Google style guide mandates 2 +# spaces. Google's externaly-published style guide says 4, consistent with +# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google +# projects (like TensorFlow). +indent-string=' ' + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. 
+notes=TODO + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=yes + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the name of dummy variables (i.e. expectedly +# not used). +dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_,_cb + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools + + +[LOGGING] + +# Logging modules to check that the string format arguments are in logging +# function parameter format +logging-modules=logging,absl.logging,tensorflow.io.logging + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[SPELLING] + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. 
+spelling-store-unknown-words=no + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub, + TERMIOS, + Bastion, + rexec, + sets + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant, absl + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls, + class_ + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..ae336850 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,10 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](http://keepachangelog.com/) +and this project adheres to [Semantic Versioning](http://semver.org/). 
+ +## [1.0.0] - 2023-12-07 + +Initial release. diff --git a/CITATION.bib b/CITATION.bib new file mode 100644 index 00000000..05d8544f --- /dev/null +++ b/CITATION.bib @@ -0,0 +1,10 @@ + +% TODO: b/311364310 - update CITATION.bib and README.md once tech report published +@inproceedings{vezhnevets2023concordia, + title={Concordia: a library for generative social simulation}, + author={Alexander Sasha Vezhnevets AND Joel Z. Leibo AND John P. Agapiou + AND Danny Karmon AND Avia Aharon AND Ron Viz + AND Jayd Matyas AND Edgar Du\'e\~nez-Guzm\'an AND Wil Cunningham + AND Simon Osindero}, + year={2023}, +} diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..faea7268 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,38 @@ +# How to contribute + +We'd love to accept your patches and contributions to this project. + +## Before you begin + +### Sign our Contributor License Agreement + +Contributions to this project must be accompanied by a +[Contributor License Agreement](https://cla.developers.google.com/about) (CLA). +You (or your employer) retain the copyright to your contribution; this simply +gives us permission to use and redistribute your contributions as part of the +project. + +If you or your current employer have already signed the Google CLA (even if it +was for a different project), you probably don't need to do it again. + +Visit <https://cla.developers.google.com/> to see your current agreements or to +sign a new one. + +### Review our community guidelines + +This project follows +[Google's Open Source Community Guidelines](https://opensource.google/conduct/). + +## Familiarize yourself with our code style + +This project follows the +[Google style guide](https://google.github.io/styleguide/). + +## Contribution process + +### Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. 
Consult +[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..6b0b1270 --- /dev/null +++ b/LICENSE @@ -0,0 +1,203 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. 
Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative 
Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ diff --git a/README.md b/README.md new file mode 100644 index 00000000..0be1360a --- /dev/null +++ b/README.md @@ -0,0 +1,103 @@ +# Concordia + +*A library for generative social simulation* + +[![Python](https://img.shields.io/pypi/pyversions/dm-concordia.svg)](https://pypi.python.org/pypi/dm-concordia) +[![PyPI version](https://img.shields.io/pypi/v/dm-concordia.svg)](https://pypi.python.org/pypi/dm-concordia) +[![PyPI tests](../../actions/workflows/pypi-test.yml/badge.svg)](../../actions/workflows/pypi-test.yml) +[![Tests](../../actions/workflows/test-concordia.yml/badge.svg)](../../actions/workflows/test-concordia.yml) +[![Examples](../../actions/workflows/test-examples.yml/badge.svg)](../../actions/workflows/test-examples.yml) + + +[Concordia Tech Report]() + +## About + +Concordia is a platform designed for constructing generative models that +simulate social interactions within a digitally-grounded action space. This +platform facilitates the emulation of agent behaviors and activities. The +framework can cater and support a wide array of applications, ranging from +social science research and AI ethics to cognitive neuroscience and economics; +Additionally, it also can be leveraged for generating data for personalization +applications and for conducting performance evaluations of real services through +simulated usage. Our system simply requires access to a standard LLM API, and +possible integration to real applications and services. The rest is python for +scaffolding, orchestration, prompt-templating, experiment design and analysis. + + +## Installation + +### `pip` install + +[Concordia is available on PyPI](https://pypi.python.org/pypi/gdm-concordia) +and can be installed using: + +```shell +pip install gdm-concordia +``` + + +### Manual install + +If you want to work on the Concordia source code, you can perform an editable +installation as follows: + +1. 
Clone Concordia: + + ```shell + git clone -b main https://github.com/google-deepmind/concordia + cd concordia + ``` + +2. Install Concordia: + + ```shell + pip install --editable .[dev] + ``` + +3. (Optional) Test the installation: + + ```shell + pytest --pyargs concordia + ``` + + +## Bring your own LLM + +To work, Concordia requires access to an LLM API. The example below is +written using [Saxml](https://github.com/google/saxml), but any LLM API that +supports sampling text and calculating log-likelihood would work. We recommend +using large (>300B parameters) models. If using a custom LLM API, the user has +to provide a text embedder to be used by the associative memory. By default we +use the Sentence-T5 for this, but any fixed-dimensional embedding would work. + +## Example usage + +Find below an illustrative social simulation with 5 players which simulates the +day of mayoral elections in an imaginary town called Riverbend. The first two +players, Alice and Bob, are running for mayor. The third player, Charlie, +is trying to ruin Alice's reputation with disinformation. The last two players +have no specific agenda, apart from voting in the election. + +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/concordia/examples/village/riverbend_elections.ipynb) + +## Citing Concordia + +If you use Concordia in your work, please cite the accompanying article: + + + +```bibtex +@inproceedings{vezhnevets2023concordia, + title={Concordia: a library for generative social simulation}, + author={Alexander Sasha Vezhnevets AND Joel Z. Leibo AND John P. Agapiou + AND Danny Karmon AND Avia Aharon AND Ron Viz + AND Jayd Matyas AND Edgar Du\'e\~nez-Guzm\'an AND Wil Cunningham + AND Simon Osindero}, + year={2023}, +} +``` + +## Disclaimer + +This is not an officially supported Google product. 
diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 00000000..a8c2236d --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,4 @@ +To report a security issue, please use [https://g.co/vulnz](https://g.co/vulnz). +We use g.co/vulnz for our intake, and do coordination and disclosure here on +GitHub (including using GitHub Security Advisory). The Google Security Team will +respond within 5 working days of your report on g.co/vulnz. diff --git a/concordia/__init__.py b/concordia/__init__.py new file mode 100644 index 00000000..21637409 --- /dev/null +++ b/concordia/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + diff --git a/concordia/agents/__init__.py b/concordia/agents/__init__.py new file mode 100644 index 00000000..21637409 --- /dev/null +++ b/concordia/agents/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + diff --git a/concordia/agents/basic_agent.py b/concordia/agents/basic_agent.py new file mode 100644 index 00000000..7201ace8 --- /dev/null +++ b/concordia/agents/basic_agent.py @@ -0,0 +1,275 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Classes to use in a basic generative agent. + +Based on: + +Park, J.S., O'Brien, J.C., Cai, C.J., Morris, M.R., Liang, P. and +Bernstein, M.S., 2023. Generative agents: Interactive simulacra of human +behavior. arXiv preprint arXiv:2304.03442. 
+""" + +import concurrent +import contextlib +import copy +import datetime + +from concordia.associative_memory import associative_memory +from concordia.document import interactive_document +from concordia.language_model import language_model +from concordia.typing import agent +from concordia.typing import clock as game_clock +from concordia.typing import component +from concordia.utils import helper_functions +from IPython import display +import termcolor + + +class BasicAgent( + agent.GenerativeAgent, + agent.SpeakerGenerativeAgent, +): + """A Generative agent.""" + + def __init__( + self, + model: language_model.LanguageModel, + memory: associative_memory.AssociativeMemory, + agent_name: str, + clock: game_clock.GameClock, + components: list[component.Component] | None = None, + num_memories_retrieved: int = 10, + update_interval: datetime.timedelta = datetime.timedelta(hours=1), + verbose: bool = False, + user_controlled: bool = False, + print_colour='green', + ): + """A generative agent. + + Args: + model: a language model + memory: an associative memory + agent_name: the name of the agent + clock: the game clock is needed to know when is the current time + components: components that contextualise the policies + num_memories_retrieved: number of memories to retrieve for acting, + speaking, testing + update_interval: how often to update components. In game time according to + the clock argument. 
+ verbose: whether to print chains of thought or not + user_controlled: if True, would query user input for speech and action + print_colour: which colour to use for printing + """ + self._verbose = verbose + self._print_colour = print_colour + + self._model = model + self._memory = memory + + self._agent_name = agent_name + self._clock = clock + self._num_memories_retrieved = num_memories_retrieved + self._user_controlled = user_controlled + self._update_interval = update_interval + + self._under_interrogation = False + + self._components = {} + for comp in components: + self.add_component(comp) + + self._log = [] + self._last_chain_of_thought = None + self._last_update = datetime.datetime.min + + @property + def name(self) -> str: + return self._agent_name + + def copy(self) -> 'BasicAgent': + """Creates a copy of the agent.""" + new_sim = BasicAgent( + model=self._model, + memory=self._memory, + agent_name=self._agent_name, + clock=self._clock, + components=copy.copy(list(self._components.values())), + num_memories_retrieved=self._num_memories_retrieved, + verbose=self._verbose, + user_controlled=self._user_controlled, + print_colour=self._print_colour, + ) + return new_sim + + def get_memory(self) -> associative_memory.AssociativeMemory: + return self._memory + + def _print(self, entry: str): + print(termcolor.colored(entry, self._print_colour), end='') + + def add_component(self, comp: component.Component) -> None: + """Add a component.""" + if comp.name() in self._components: + raise ValueError(f'Duplicate component name: {comp.name()}') + else: + self._components[comp.name()] = comp + + def remove_component(self, component_name: str) -> None: + """Remove a component.""" + del self._components[component_name] + + def set_clock(self, clock: game_clock.GameClock): + self._clock = clock + + def enter_interrogation(self): + self._under_interrogation = True + + def leave_interrogation(self): + self._under_interrogation = False + + @contextlib.contextmanager + def 
interrogate(self): + """Context manager to interrogate the agent. + + When in this context, agent makes no memories or observations and doesn't + update components. + + Yields: + None + """ + self.enter_interrogation() + try: + yield + finally: + self.leave_interrogation() + + def _ask_for_input(self, context: str, prompt: str) -> str: + display.clear_output() + print(context, flush=True) + result = input(prompt) + return result + + def get_last_log(self): + return self._last_chain_of_thought + + def state(self): + return '\n'.join( + f"{self._agent_name}'s " + (comp.name() + ':\n' + comp.state()) + for comp in self._components.values() + ) + + def _maybe_update(self): + next_update = self._last_update + self._update_interval + if self._clock.now() >= next_update and not self._under_interrogation: + self.update() + + def update(self): + self._last_update = self._clock.now() + with concurrent.futures.ThreadPoolExecutor() as executor: + for comp in self._components.values(): + executor.submit(comp.update) + + def observe(self, observation: str): + if observation and not self._under_interrogation: + for comp in self._components.values(): + comp.observe(observation) + + def act( + self, + action_spec: agent.ActionSpec = agent.DEFAULT_ACTION_SPEC, + memorize: bool = False, + ): + if not action_spec: + action_spec = agent.DEFAULT_ACTION_SPEC + self._maybe_update() + prompt = interactive_document.InteractiveDocument(self._model) + context_of_action = '\n'.join([ + f'{self.state()}', + ]) + + prompt.statement(context_of_action) + + call_to_action = action_spec.call_to_action.format( + agent_name=self._agent_name, + timedelta=helper_functions.timedelta_to_readable_str( + self._clock.get_step_size() + ), + ) + output = '' + + if action_spec.output_type == 'FREE': + if self._user_controlled: + output = self._ask_for_input( + context_of_action, + call_to_action + '\n', + ) + else: + output = prompt.open_question( + call_to_action, + max_characters=1200, + max_tokens=1200, + ) 
+ elif action_spec.output_type == 'CHOICE': + idx = prompt.multiple_choice_question( + question=call_to_action, answers=action_spec.options + ) + output = action_spec.options[idx] + elif action_spec.output_type == 'FLOAT': + raise NotImplementedError + + self._last_chain_of_thought = prompt.view().text().splitlines() + + if self._verbose: + self._print( + f'\n{self._agent_name} context of action:\n' + + prompt.view().text() + + '\n' + ) + + if memorize and not self._under_interrogation: # observe instead? + if action_spec.tag: + self._memory.add( + f'[{action_spec.tag}] {output}', tags=[action_spec.tag] + ) + else: + self._memory.add(output) + + return output + + def add_memory(self, memory: str, importance: float | None = None): + self._memory.add(memory, importance=importance) + + def say(self, conversation: str) -> str: + convo_context = ( + f'{self._agent_name} is in the following' + f' conversation:\n{conversation}\n' + ) + call_to_speach = ( + f'Given the above, what should {self._agent_name} say next? Respond in' + f' the format `{self._agent_name} says: "..."` For example, ' + 'Cristina says: "Hello! Mighty fine weather today, right?" ' + 'or Ichabod says: "I wonder if the alfalfa is ready to harvest.\n' + ) + if self._user_controlled: + utterance = self._ask_for_input( + convo_context + call_to_speach, f'{self._agent_name}:' + ) + else: + utterance = self.act( + action_spec=agent.ActionSpec(convo_context + call_to_speach, 'FREE'), + ) + + return utterance diff --git a/concordia/agents/components/__init__.py b/concordia/agents/components/__init__.py new file mode 100644 index 00000000..976554bf --- /dev/null +++ b/concordia/agents/components/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Library of components for generative game master and agents.""" + +from concordia.agents.components import characteristic +from concordia.agents.components import constant +from concordia.agents.components import identity +from concordia.agents.components import observation +from concordia.agents.components import person_by_situation +from concordia.agents.components import plan +from concordia.agents.components import reflection +from concordia.agents.components import report_state +from concordia.agents.components import self_perception +from concordia.agents.components import sequential +from concordia.agents.components import situation_perception +from concordia.agents.components import somatic_state diff --git a/concordia/agents/components/characteristic.py b/concordia/agents/components/characteristic.py new file mode 100644 index 00000000..85415da8 --- /dev/null +++ b/concordia/agents/components/characteristic.py @@ -0,0 +1,119 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +"""Agent characteristic component.""" + +from concordia.associative_memory import associative_memory +from concordia.document import interactive_document +from concordia.language_model import language_model +from concordia.typing import clock as game_clock +from concordia.typing import component +import termcolor + + +class Characteristic(component.Component): + """Implements a simple characteristic component. + + For example, "current daily occupation", "core characteristic" or "hunger". + The component queries the memory for the agent's characteristic and then + summarises it. + + In psychology it is common to distinguish between `state` characteristics and + `trait` characteristics. A `state` is temporary, like being hungry or afraid, + but a `trait` endures over a long period of time, e.g. being neurotic or + extroverted. + + When the characteristic is a `state` (as opposed to a `trait`) then time is + used in the query for memory retrieval and the instruction for summarization. + When the characteristic is a `trait` then time is not used. + + When you pass a `state_clock` while creating a characteristic then you create + a `state` characteristic. When you do not pass a `state_clock` then you create + a `trait` characteristic. + """ + + def __init__( + self, + model: language_model.LanguageModel, + memory: associative_memory.AssociativeMemory, + agent_name: str, + characteristic_name: str, + state_clock: game_clock.GameClock | None = None, + extra_instructions: str = '', + num_memories_to_retrieve: int = 25, + verbose: bool = False, + ): + """Represents a characteristic of an agent (a trait or a state). + + Args: + model: a language model + memory: an associative memory + agent_name: the name of the agent + characteristic_name: the string to use in similarity search of memory + state_clock: if None then consider this component as representing a + `trait`. If a clock is used then consider this component to represent a + `state`. 
A state is temporary whereas a trait is meant to endure. + extra_instructions: append additional instructions when asking the model + to assess the characteristic. + num_memories_to_retrieve: how many memories to retrieve during the update + verbose: whether or not to print intermediate reasoning steps. + """ + self._verbose = verbose + self._model = model + self._memory = memory + self._cache = '' + self._characteristic_name = characteristic_name + self._agent_name = agent_name + self._extra_instructions = extra_instructions + self._clock = state_clock + self._num_memories_to_retrieve = num_memories_to_retrieve + + def name(self) -> str: + return self._characteristic_name + + def state(self) -> str: + return self._cache + + def update(self) -> None: + query = f"{self._agent_name}'s {self._characteristic_name}" + if self._clock is not None: + query = f'[{self._clock.now()}] {query}' + + mems = '\n'.join( + self._memory.retrieve_associative(query, + self._num_memories_to_retrieve, + add_time=True) + ) + + prompt = interactive_document.InteractiveDocument(self._model) + + question = ( + f"How would one describe {self._agent_name}'s" + f' {self._characteristic_name} given the following statements? ' + f'{self._extra_instructions}' + f'Start the answer with "{self._agent_name} is"' + ) + if self._clock is not None: + question = f'Current time: {self._clock.now()}.\n{question}' + + self._cache = prompt.open_question( + '\n'.join([question, f'Statements:\n{mems}']), + max_characters=3000, + max_tokens=1000, + ) + + self._last_chain = prompt + if self._verbose: + print(termcolor.colored(self._last_chain.view().text(), 'red'), end='') diff --git a/concordia/agents/components/constant.py b/concordia/agents/components/constant.py new file mode 100644 index 00000000..ed0bb032 --- /dev/null +++ b/concordia/agents/components/constant.py @@ -0,0 +1,48 @@ +# Copyright 2023 DeepMind Technologies Limited. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""This component always returns the same string.""" + +from concordia.typing import component + + +class ConstantConstruct(component.Component): + """A constant memory component.""" + + def __init__(self, state: str, name: str = 'constant'): + """Initializes the constant component. + + Args: + state: The state of the memory component. + name: The name of the memory component. + """ + self._state = state + self._name = name + + def name(self) -> str: + """Returns the name of the memory component.""" + return self._name + + def state(self) -> str: + """Returns the state of the memory component.""" + return self._state + + def update(self) -> None: + """This component always returns the same string, update does nothing.""" + pass + + def set_state(self, state: str) -> None: + """Set the constant state.""" + self._state = state diff --git a/concordia/agents/components/identity.py b/concordia/agents/components/identity.py new file mode 100644 index 00000000..3fd2767a --- /dev/null +++ b/concordia/agents/components/identity.py @@ -0,0 +1,83 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Agent identity component."""
+
+# Import the submodule explicitly: `import concurrent` alone does not make
+# `concurrent.futures` (used by `update`) available.
+import concurrent.futures
+from concordia.agents.components import characteristic
+from concordia.associative_memory import associative_memory
+from concordia.language_model import language_model
+from concordia.typing import component
+
+
+class SimIdentity(component.Component):
+  """Identity component containing a few characteristics.
+
+  Identity is built out of 3 characteristics:
+  1. 'core characteristics',
+  2. 'current daily occupation',
+  3. 'feeling about recent progress in life',
+  """
+
+  def __init__(
+      self,
+      model: language_model.LanguageModel,
+      memory: associative_memory.AssociativeMemory,
+      agent_name: str,
+  ):
+    """Initialize an identity component.
+ + Args: + model: a language model + memory: an associative memory + agent_name: the name of the agent + """ + self._model = model + self._memory = memory + self._state = '' + self._agent_name = agent_name + + self._identity_component_names = [ + 'core characteristics', + 'current daily occupation', + 'feeling about recent progress in life', + ] + + self._identity_components = [] + + for component_name in self._identity_component_names: + self._identity_components.append( + characteristic.Characteristic( + model=model, + memory=self._memory, + agent_name=self._agent_name, + characteristic_name=component_name, + ) + ) + + def name(self) -> str: + return 'Identity' + + def state(self): + return self._state + + def update(self): + with concurrent.futures.ThreadPoolExecutor() as executor: + for c in self._identity_components: + executor.submit(c.update) + + self._state = f'Name: {self._agent_name}\n' + '\n'.join( + [c.state() for c in self._identity_components] + ) diff --git a/concordia/agents/components/observation.py b/concordia/agents/components/observation.py new file mode 100644 index 00000000..95e515e8 --- /dev/null +++ b/concordia/agents/components/observation.py @@ -0,0 +1,148 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +"""Agent components for representing observation stream.""" + +from concordia.associative_memory import associative_memory +from concordia.document import interactive_document +from concordia.language_model import language_model +from concordia.typing import component +import termcolor + + +class Observation(component.Component): + """Component that stacks current observations together, clears on update.""" + + def __init__( + self, + agent_name: str, + memory: associative_memory.AssociativeMemory, + component_name: str = 'Current observation', + verbose: bool = False, + log_colour='green', + ): + """Initialize the observation component. + + Args: + agent_name: the name of the agent + memory: memory for writing observations into + component_name: the name of this component + verbose: whether or not to print intermediate reasoning steps + log_colour: colour for logging + """ + self._agent_name = agent_name + self._log_colour = log_colour + self._name = component_name + self._memory = memory + + self._last_observation = [] + + self._verbose = verbose + + def name(self) -> str: + return self._name + + def state(self): + if self._verbose: + self._log('\n'.join(self._last_observation) + '\n') + return '\n'.join(self._last_observation) + '\n' + + def _log(self, entry: str): + print(termcolor.colored(entry, self._log_colour), end='') + + def observe(self, observation: str): + self._last_observation.append(observation) + self._memory.add( + f'[observation] {observation}', + tags=['observation'], + ) + + def update(self): + self._last_observation = [] + return '' + + +class ObservationSummary(component.Component): + """Component that summarises current observations on update.""" + + def __init__( + self, + model: language_model.LanguageModel, + agent_name: str, + components: list[component.Component], + verbose: bool = False, + log_colour='green', + ): + """Initialize a construct containing the agent's plan for the day. 
+ + Args: + model: a language model + agent_name: the name of the agent + components: components to condition observation summarisation + verbose: whether or not to print intermediate reasoning steps + log_colour: colour for logging + """ + self._model = model + self._state = '' + self._agent_name = agent_name + self._log_colour = log_colour + self._components = components + + self._last_observation = [] + + self._verbose = verbose + + def name(self) -> str: + return 'Summary of recent observations' + + def state(self): + return self._state + + def _log(self, entry: str): + print(termcolor.colored(entry, self._log_colour), end='') + + def observe(self, observation: str): + self._last_observation.append(observation) + + def update(self): + context = '\n'.join( + [ + f"{self._agent_name}'s " + + (construct.name() + ':\n' + construct.state()) + for construct in self._components + ] + ) + + numbered_observations = [ + f'{i}. {observation}' + for i, observation in enumerate(self._last_observation) + ] + current_observations = '\n'.join(numbered_observations) + + prompt = interactive_document.InteractiveDocument(self._model) + prompt.statement(context + '\n') + prompt.statement( + 'Current observations, numbered in chronological order:\n' + + f'{current_observations}\n' + ) + self._state = prompt.open_question( + 'Summarize the observations into one sentence.' + ) + + self._last_observation = [] + + if self._verbose: + self._log('\nObservation summary:') + self._log('\n' + prompt.view().text() + '\n') diff --git a/concordia/agents/components/person_by_situation.py b/concordia/agents/components/person_by_situation.py new file mode 100644 index 00000000..19dd7d05 --- /dev/null +++ b/concordia/agents/components/person_by_situation.py @@ -0,0 +1,107 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Agent component for the person-by-situation question."""
+
+from typing import Sequence
+from concordia.associative_memory import associative_memory
+from concordia.document import interactive_document
+from concordia.language_model import language_model
+from concordia.typing import clock
+from concordia.typing import component
+import termcolor
+
+
+class PersonBySituation(component.Component):
+  """What would a person like the agent do in a situation like this?"""
+
+  def __init__(
+      self,
+      name: str,
+      model: language_model.LanguageModel,
+      memory: associative_memory.AssociativeMemory,
+      agent_name: str,
+      # Fixed: was `components=Sequence[...] | None,` which made the *default
+      # value* a typing expression (truthy), so `components or []` stored a
+      # type object instead of a list of components.
+      components: Sequence[component.Component] | None = None,
+      state_clock: clock.GameClock | None = None,
+      num_memories_to_retrieve: int = 25,
+      verbose: bool = False,
+  ):
+    """Initializes the PersonBySituation component.
+
+    Args:
+      name: The name of the component.
+      model: The language model to use.
+      memory: The memory to use.
+      agent_name: The name of the agent.
+      components: The components to condition the answer on. Defaults to none.
+      state_clock: The clock to use.
+      num_memories_to_retrieve: The number of memories to retrieve.
+      verbose: Whether to print the state of the component.
+ """ + + self._verbose = verbose + self._model = model + self._memory = memory + self._state = '' + self._components = components or [] + self._agent_name = agent_name + self._clock = state_clock + self._num_memories_to_retrieve = num_memories_to_retrieve + self._name = name + + def name(self) -> str: + return self._name + + def state(self) -> str: + return self._state + + def update(self) -> None: + prompt = interactive_document.InteractiveDocument(self._model) + + mems = '\n'.join( + self._memory.retrieve_recent( + self._num_memories_to_retrieve, add_time=True + ) + ) + + prompt.statement(f'Memories of {self._agent_name}:\n{mems}') + + component_states = '\n'.join( + [ + f"{self._agent_name}'s " + + (construct.name() + ':\n' + construct.state()) + for construct in self._components + ] + ) + + prompt.statement(component_states) + question = ( + f'What would a person like {self._agent_name} do in a situation like' + ' this?' + ) + if self._clock is not None: + question = f'Current time: {self._clock.now()}.\n{question}' + + self._state = prompt.open_question( + question, + answer_prefix=f'{self._agent_name} would ', + max_characters=3000, + max_tokens=1000, + ) + + self._state = f'{self._agent_name} would {self._state}' + + self._last_chain = prompt + if self._verbose: + print(termcolor.colored(self._last_chain.view().text(), 'red'), end='') diff --git a/concordia/agents/components/plan.py b/concordia/agents/components/plan.py new file mode 100644 index 00000000..217ddb8d --- /dev/null +++ b/concordia/agents/components/plan.py @@ -0,0 +1,145 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Agent components for planning.""" + +from concordia.associative_memory import associative_memory +from concordia.document import interactive_document +from concordia.language_model import language_model +from concordia.typing import component +import termcolor + + +class SimPlan(component.Component): + """Component representing the agent's plan.""" + + def __init__( + self, + model: language_model.LanguageModel, + memory: associative_memory.AssociativeMemory, + agent_name: str, + components: list[component.Component], + goal: component.Component | None = None, + num_memories_to_retrieve: int = 5, + timescale: str = 'the rest of the day', + time_adverb: str = 'hourly', + verbose: bool = False, + log_colour='green', + ): + """Initialize a component to represent the agent's plan. 
+ + Args: + model: a language model + memory: an associative memory + agent_name: the name of the agent + components: components to build the context of planning + goal: a component to represent the goal of planning + num_memories_to_retrieve: how many memories to retrieve as conditioning + for the planning chain of thought + timescale: string describing how long the plan should last + time_adverb: string describing the rate of steps in the plan + verbose: whether or not to print intermediate reasoning steps + log_colour: colour for logging + """ + self._model = model + self._memory = memory + self._state = '' + self._agent_name = agent_name + self._log_colour = log_colour + self._components = components + self._num_memories_to_retrieve = num_memories_to_retrieve + self._goal_component = goal + self._timescale = timescale + self._time_adverb = time_adverb + + self._latest_memories = '' + self._last_observation = [] + self._current_plan = '' + + self._verbose = verbose + + def name(self) -> str: + return 'Plan' + + def state(self): + return self._state + + def _log(self, entry: str): + print(termcolor.colored(entry, self._log_colour), end='') + + def observe(self, observation: str): + self._last_observation.append(observation) + + def update(self, push_to_mem=True): + observation = '\n'.join(self._last_observation) + self._last_observation = [] + memories = self._memory.retrieve_associative( + observation, + k=self._num_memories_to_retrieve, + use_recency=True, + add_time=True, + ) + if self._goal_component: + memories = memories + self._memory.retrieve_associative( + self._goal_component.state(), + k=self._num_memories_to_retrieve, + use_recency=True, + add_time=True, + ) + memories = '\n'.join(memories) + + components = '\n'.join( + [ + f"{self._agent_name}'s " + + (construct.name() + ':\n' + construct.state()) + for construct in self._components + ] + ) + + prompt = interactive_document.InteractiveDocument(self._model) + prompt.statement(f'{components}\n') + 
prompt.statement(f'Relevant memories:\n{memories}') + if self._goal_component: + prompt.statement(f'Current goal: {self._goal_component.state()}.') + prompt.statement(f'Current plan: {self._current_plan}') + prompt.statement(f'Current situation: {observation}') + should_replan = prompt.yes_no_question( + ( + f'Given the above, should {self._agent_name} change their current ' + + 'plan?' + ) + ) + + if should_replan or not self._state: + goal_mention = '.' + if self._goal_component: + goal_mention = ', keep in mind the goal.' + self._current_plan = prompt.open_question( + f"What is {self._agent_name}'s plan for {self._timescale}? Please," + f' provide a {self._time_adverb} schedule' + + goal_mention, + max_characters=1200, + max_tokens=1200, + terminators=(), + ) + if self._goal_component: + self._state = ( + f'The goal: {self._goal_component.state()}\n{self._current_plan}' + ) + else: + self._state = self._current_plan + + if self._verbose: + self._log('\n' + prompt.view().text() + '\n') diff --git a/concordia/agents/components/reflection.py b/concordia/agents/components/reflection.py new file mode 100644 index 00000000..dfba6a0d --- /dev/null +++ b/concordia/agents/components/reflection.py @@ -0,0 +1,115 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +"""Agent characteristic component.""" + +from concordia.associative_memory import associative_memory +from concordia.document import interactive_document +from concordia.language_model import language_model +from concordia.typing import component +import termcolor + + +class Reflection(component.Component): + """Implements a reflection component. + + First, the last 100 memories are retrieved, using which the 3 most salient + questions are inferred. These questions are used as a query. The output of + the query is summarised into 5 insights. + """ + + def __init__( + self, + model: language_model.LanguageModel, + memory: associative_memory.AssociativeMemory, + agent_name: str, + name: str = 'reflection', + importance_threshold: float = 20.0, + verbose: bool = False, + ): + self._model = model + self._memory = memory + self._state = '' + self._agent_name = agent_name + self._verbose = verbose + self._name = name + self._importance_threshold = importance_threshold + + def name(self) -> str: + return self._name + + def state(self) -> str: + return self._state + + def update(self) -> None: + mems, importance = self._memory.retrieve_recent_with_importance( + 100, add_time=True + ) + total_importance = sum(importance) + if total_importance < self._importance_threshold: + self._state = '' + if self._verbose: + print( + termcolor.colored( + f'Importance {total_importance} below threshold', 'green' + ), + end='', + ) + + return + + mems = '\n'.join(mems) + + prompt = interactive_document.InteractiveDocument(self._model) + + questions = prompt.open_question( + '\n'.join([ + f'{mems}', + ( + 'Given only the statements above, what are' + ' the 3 most salient high-level questions we can' + f' answer about {self._agent_name}?' 
+ ), + ]), + max_characters=5000, + max_tokens=5000, + terminators=(), + ) + + mems = [] + # make sure that the answer comes out of LLM in the right format + for question in questions.splitlines(): + mems += self._memory.retrieve_associative(question, 10, add_time=True) + + mems = '\n'.join(mems) + + prompt = interactive_document.InteractiveDocument(self._model) + + self._state = prompt.open_question( + '\n'.join([ + f'{mems}', + ( + 'What 5 high-level insights can you infer from the above' + ' statements?' + ), + ]), + max_characters=5000, + max_tokens=5000, + terminators=(), + ) + self._memory.extend(self._state.splitlines()) + self._last_chain = prompt + if self._verbose: + print(termcolor.colored(self._last_chain.view().text(), 'green'), end='') diff --git a/concordia/agents/components/report_state.py b/concordia/agents/components/report_state.py new file mode 100644 index 00000000..6e153474 --- /dev/null +++ b/concordia/agents/components/report_state.py @@ -0,0 +1,52 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""This components report what the get_state returns at the moment. 
+
+For example, can be used for reporting current time
+current_time_component = ReportState(
+    name='Current time',
+    get_state=clock.current_time_interval_str)
+
+"""
+
+from typing import Callable
+from concordia.typing import component
+
+
+class ReportState(component.Component):
+  """A component that reports whatever its `get_state` callable returns."""
+
+  def __init__(self, get_state: Callable[[], str], name: str = 'State'):
+    """Initializes the component.
+
+    Args:
+      get_state: callable returning the string to report as this component's
+        state, e.g. a clock's current-time formatter.
+      name: The name of the component.
+    """
+    self._get_state = get_state
+    self._name = name
+
+  def name(self) -> str:
+    """Returns the name of the component."""
+    return self._name
+
+  def state(self) -> str:
+    """Returns the state of the component, freshly read from `get_state`."""
+    return self._get_state()
+
+  def update(self) -> None:
+    """State is produced on demand by `get_state`; update does nothing."""
+    pass
diff --git a/concordia/agents/components/self_perception.py b/concordia/agents/components/self_perception.py
new file mode 100644
index 00000000..8d053633
--- /dev/null
+++ b/concordia/agents/components/self_perception.py
@@ -0,0 +1,93 @@
+# Copyright 2023 DeepMind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+ +"""Agent component for self perception.""" + +from concordia.associative_memory import associative_memory +from concordia.document import interactive_document +from concordia.language_model import language_model +from concordia.typing import clock +from concordia.typing import component +import termcolor + + +class SelfPerception(component.Component): + """This component answers the question 'what kind of person is the agent?'.""" + + def __init__( + self, + name: str, + model: language_model.LanguageModel, + memory: associative_memory.AssociativeMemory, + agent_name: str, + state_clock: clock.GameClock | None = None, + num_memories_to_retrieve: int = 100, + verbose: bool = False, + ): + """Initializes the SelfPerception component. + + Args: + name: Name of the component. + model: Language model. + memory: Associative memory. + agent_name: Name of the agent. + state_clock: Clock to use for the state. + num_memories_to_retrieve: Number of memories to retrieve. + verbose: Whether to print the state. + """ + + self._verbose = verbose + self._model = model + self._memory = memory + self._state = '' + self._agent_name = agent_name + self._clock = state_clock + self._num_memories_to_retrieve = num_memories_to_retrieve + self._name = name + + def name(self) -> str: + return self._name + + def state(self) -> str: + return self._state + + def update(self) -> None: + mems = '\n'.join( + self._memory.retrieve_recent( + self._num_memories_to_retrieve, add_time=True + ) + ) + + prompt = interactive_document.InteractiveDocument(self._model) + prompt.statement(f'Memories of {self._agent_name}:\n{mems}') + + if self._clock is not None: + prompt.statement(f'Current time: {self._clock.now()}.\n') + + question = ( + f'Given the memories above, what kind of person is {self._agent_name}?' 
+ ) + + self._state = prompt.open_question( + question, + answer_prefix=f'{self._agent_name} is ', + max_characters=3000, + max_tokens=1000, + ) + + self._state = f'{self._agent_name} is {self._state}' + + self._last_chain = prompt + if self._verbose: + print(termcolor.colored(self._last_chain.view().text(), 'green'), end='') diff --git a/concordia/agents/components/sequential.py b/concordia/agents/components/sequential.py new file mode 100644 index 00000000..7dc700ce --- /dev/null +++ b/concordia/agents/components/sequential.py @@ -0,0 +1,72 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Component that chain components in a sequential way, removing concurrency.""" + +from typing import Sequence + +from concordia.typing import component + + +class Sequential(component.Component): + """Chains components, removing concurrency.""" + + def __init__(self, name: str, components: Sequence[component.Component]): + self._components = components + self._name = name + + def update(self) -> None: + for comp in self._components: + comp.update() + + def state(self) -> str: + return '\n'.join( + [comp.name() + ': ' + comp.state() for comp in self._components] + ) + + def partial_state(self, player_name: str) -> str | None: + return '\n'.join( + [comp.partial_state(player_name) for comp in self._components] + ) + + def observe(self, observation: str): + for comp in self._components: + comp.observe(observation) + + def update_before_event(self, event_statement: str) -> None: + for comp in self._components: + comp.update_before_event(event_statement) + + def update_after_event(self, event_statement: str) -> None: + for comp in self._components: + comp.update_after_event(event_statement) + + def terminate_episode(self) -> bool: + for comp in self._components: + if comp.terminate_episode(): + return True + return False + + def name(self) -> str: + return self._name + + def get_last_log( + self, + ): + """Returns a dictionary with latest log of activity.""" + output = {} + for comp in self._components: + output[comp.name()] = comp.get_last_log() + + return output diff --git a/concordia/agents/components/situation_perception.py b/concordia/agents/components/situation_perception.py new file mode 100644 index 00000000..93111de9 --- /dev/null +++ b/concordia/agents/components/situation_perception.py @@ -0,0 +1,92 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Agent component for situation perception.""" + +from concordia.associative_memory import associative_memory +from concordia.document import interactive_document +from concordia.language_model import language_model +from concordia.typing import clock +from concordia.typing import component +import termcolor + + +class SituationPerception(component.Component): + """This component answers the question 'what kind of situation is it?'.""" + + def __init__( + self, + name: str, + model: language_model.LanguageModel, + memory: associative_memory.AssociativeMemory, + agent_name: str, + state_clock: clock.GameClock | None = None, + num_memories_to_retrieve: int = 25, + verbose: bool = False, + ): + """Initializes the component. + + Args: + name: The name of the component. + model: The language model to use. + memory: The memory to use. + agent_name: The name of the agent. + state_clock: The clock to use. + num_memories_to_retrieve: The number of memories to retrieve. + verbose: Whether to print the last chain. 
+ """ + self._verbose = verbose + self._model = model + self._memory = memory + self._state = '' + self._agent_name = agent_name + self._clock = state_clock + self._num_memories_to_retrieve = num_memories_to_retrieve + self._name = name + + def name(self) -> str: + return self._name + + def state(self) -> str: + return self._state + + def update(self) -> None: + mems = '\n'.join( + self._memory.retrieve_recent( + self._num_memories_to_retrieve, add_time=True + ) + ) + + prompt = interactive_document.InteractiveDocument(self._model) + prompt.statement(f'Memories of {self._agent_name}:\n{mems}') + + if self._clock is not None: + prompt.statement(f'Current time: {self._clock.now()}.\n') + + question = ( + 'Given the memories above, what kind of situation is' + f' {self._agent_name} in?' + ) + + self._state = prompt.open_question( + question, + answer_prefix=f'{self._agent_name} is currently ', + max_characters=3000, + max_tokens=1000, + ) + self._state = f'{self._agent_name} is currently {self._state}' + + self._last_chain = prompt + if self._verbose: + print(termcolor.colored(self._last_chain.view().text(), 'green'), end='') diff --git a/concordia/agents/components/somatic_state.py b/concordia/agents/components/somatic_state.py new file mode 100644 index 00000000..f994cdf5 --- /dev/null +++ b/concordia/agents/components/somatic_state.py @@ -0,0 +1,116 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +"""Agent component for tracking the somatic state.""" + +import concurrent +from concordia.agents.components import characteristic +from concordia.associative_memory import associative_memory +from concordia.language_model import language_model +from concordia.typing import clock as game_clock +from concordia.typing import component + + +class SomaticState(component.Component): + """Somatic state component containing a five characteristics. + + Somatic state is comprised of hunger, thirst, fatigue, pain and feeling + socially connected to life. + """ + + def __init__( + self, + model: language_model.LanguageModel, + memory: associative_memory.AssociativeMemory, + agent_name: str, + clock: game_clock.GameClock, + summarize: bool = True, + ): + """Initialize somatic state component. + + Args: + model: a language model + memory: an associative memory + agent_name: the name of the agent + clock: the game clock is needed to know when is the current time + summarize: if True, the resulting state will be a one sentence summary, + otherwise state it would be a concatentation of five separate + characteristics + """ + self._model = model + self._memory = memory + self._state = '' + self._agent_name = agent_name + self._clock = clock + self._summarize = summarize + + self._characteristic_names = [ + 'level of hunger', + 'level of thirst', + 'level of fatigue', + 'level of pain', + 'level of feeling socially connected in life', + ] + + self._characteristics = [] + + extra_instructions = ( + 'Be literal. Do not use any metaphorical language. ' + + 'When there is insufficient evidence to infer a ' + + 'specific answer then guess the most likely one. ' + + 'Never express uncertainty unless ' + + f'{self._agent_name} would be uncertain.' 
+ ) + + for characteristic_name in self._characteristic_names: + self._characteristics.append( + characteristic.Characteristic( + model=model, + memory=self._memory, + agent_name=self._agent_name, + characteristic_name=characteristic_name, + state_clock=self._clock, + extra_instructions=extra_instructions, + ) + ) + + def name(self) -> str: + return 'Somatic state' + + def state(self): + return self._state + + def update(self): + with concurrent.futures.ThreadPoolExecutor() as executor: + for c in self._characteristics: + executor.submit(c.update) + + self._state = '\n'.join( + [ + f"{self._agent_name}'s {c.name()}: " + c.state() + for c in self._characteristics + ] + ) + if self._summarize: + prompt = ( + f'Summarize the somatic state of {self._agent_name} in one' + ' sentence given the readings below. Only mention readings that' + f' deviate from the norm, for example if {self._agent_name} is not' + ' hungry do not mention hunger at all.\nReadings:\n' + + self._state + ) + self._state = f'{self._agent_name} is ' + self._model.sample_text( + f'{prompt}\n {self._agent_name} is ', max_tokens=500 + ) diff --git a/concordia/associative_memory/__init__.py b/concordia/associative_memory/__init__.py new file mode 100644 index 00000000..21637409 --- /dev/null +++ b/concordia/associative_memory/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + diff --git a/concordia/associative_memory/associative_memory.py b/concordia/associative_memory/associative_memory.py new file mode 100644 index 00000000..71aeb6b0 --- /dev/null +++ b/concordia/associative_memory/associative_memory.py @@ -0,0 +1,287 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""An associative memory similar to the one in the following paper. + +Park, J.S., O'Brien, J.C., Cai, C.J., Morris, M.R., Liang, P. and Bernstein, +M.S., 2023. Generative agents: Interactive simulacra of human behavior. arXiv +preprint arXiv:2304.03442. +""" +from collections.abc import Callable, Iterable +import datetime + +import numpy as np +import pandas as pd +import rwlock + + +class AssociativeMemory: + """Class that implements associative memory.""" + + def __init__( + self, + sentence_embedder: Callable[[str], np.ndarray], + importance: Callable[[str], float], + clock: Callable[[], datetime.datetime] = datetime.datetime.now, + clock_step_size: datetime.timedelta | None = None, + ): + """Constructor. + + Args: + sentence_embedder: text embedding model + importance: maps a sentence into [0,1] scale of importance + clock: a callable to get time when adding memories + clock_step_size: sets the step size of the clock. 
If None, assumes precise + time + """ + self._memory_bank_lock = rwlock.ReadWriteLock() + self._embedder = sentence_embedder + self._importance = importance + + self._memory_bank = pd.DataFrame( + columns=['text', 'time', 'tags', 'embedding', 'importance'] + ) + self._clock_now = clock + self._interval = clock_step_size + + def add( + self, + text: str, + *, + timestamp: datetime.datetime | None = None, + tags: list[str] | None = None, + importance: float | None = None, + ): + """Adds the text to the memory. + + Args: + text: what goes into the memory + timestamp: the time of the memory + tags: optional tags + importance: optionally set the importance of the memory. + """ + + embedding = self._embedder(text) + if importance is None: + importance = self._importance(text) + + if timestamp is None: + timestamp = self._clock_now() + + new_df = ( + pd.Series({ + 'text': text, + 'time': timestamp, + 'tags': tags, + 'embedding': embedding, + 'importance': importance, + }) + .to_frame() + .T + ) + + with self._memory_bank_lock.AcquireWrite(): + self._memory_bank = pd.concat( + [self._memory_bank, new_df], ignore_index=True + ) + + def extend( + self, + texts: Iterable[str], + **kwargs, + ): + """Adds the texts to the memory. + + Args: + texts: list of strings to add to the memory + **kwargs: arguments to pass on to .add + """ + for text in texts: + self.add(text, **kwargs) + + def get_data_frame(self): + with self._memory_bank_lock.AcquireRead(): + return self._memory_bank.copy() + + def _get_top_k_cosine(self, x: np.ndarray, k: int): + """Returns the top k rows of a dataframe that have the highest cosine similarity to an input vector x. + + Args: + x: The input vector. + k: The number of rows to return. + + Returns: + Rows, sorted by cosine similarity in descending order. + """ + with self._memory_bank_lock.AcquireRead(): + cosine_similarities = self._memory_bank['embedding'].apply( + lambda y: np.dot(x, y) + ) + + # Sort the cosine similarities in descending order. 
+ cosine_similarities.sort_values(ascending=False, inplace=True) + + # Return the top k rows. + return self._memory_bank.iloc[cosine_similarities.head(k).index] + + def _get_top_k_similar_rows( + self, x, k: int, use_recency: bool = True, use_importance: bool = True + ): + """Returns the top k rows of a dataframe that have the highest cosine similarity to an input vector x. + + Args: + x: The input vector. + k: The number of rows to return. + use_recency: if true then weight similarity by recency + use_importance: if true then weight similarity by importance + + Returns: + Rows, sorted by cosine similarity in descending order. + """ + with self._memory_bank_lock.AcquireRead(): + cosine_similarities = self._memory_bank['embedding'].apply( + lambda y: np.dot(x, y) + ) + + similarity_score = cosine_similarities + + if use_recency: + max_time = self._memory_bank['time'].max() + discounted_time = self._memory_bank['time'].apply( + lambda y: 0.99 ** ((max_time - y) / datetime.timedelta(minutes=1)) + ) + similarity_score += discounted_time + + if use_importance: + importance = self._memory_bank['importance'] + similarity_score += importance + + # Sort the similarities in descending order. + similarity_score.sort_values(ascending=False, inplace=True) + + # Return the top k rows. + return self._memory_bank.iloc[similarity_score.head(k).index] + + def _get_k_recent(self, k: int): + with self._memory_bank_lock.AcquireRead(): + recency = self._memory_bank['time'].sort_values(ascending=False) + return self._memory_bank.iloc[recency.head(k).index] + + def _pd_to_text( + self, + data: pd.DataFrame, + add_time: bool = False, + sort_by_time: bool = True, + ): + """Formats a dataframe into list of strings. 
+ + Args: + data: the dataframe to process + add_time: whether to add time + sort_by_time: whether to sort by time + + Returns: + A list of strings, one for each memory + """ + if sort_by_time: + data = data.sort_values('time', ascending=True) + + if add_time and not data.empty: + if self._interval: + this_time = data['time'] + next_time = data['time'] + self._interval + + interval = this_time.dt.strftime( + '%d %b %Y [%H:%M:%S ' + ) + next_time.dt.strftime('- %H:%M:%S]: ') + output = interval + data['text'] + else: + output = data['time'].dt.strftime('[%d %b %Y %H:%M:%S] ') + data['text'] + else: + output = data['text'] + + return output.tolist() + + def retrieve_associative( + self, + query: str, + k: int = 1, + use_recency: bool = True, + use_importance: bool = True, + add_time: bool = True, + sort_by_time: bool = True, + ): + """Retrieve memories associatively. + + Args: + query: a string to use for retrieval + k: how many memories to retrieve + use_recency: whether to use timestamps to weight by recency or not + use_importance: whether to use importance for retrieval + add_time: whether to add time stamp to the output + sort_by_time: whether to sort the result by time + + Returns: + List of strings corresponding to memories + """ + query_embedding = self._embedder(query) + + data = self._get_top_k_similar_rows( + query_embedding, + k, + use_recency=use_recency, + use_importance=use_importance, + ) + + return self._pd_to_text(data, add_time=add_time, sort_by_time=sort_by_time) + + def retrieve_recent( + self, + k: int = 1, + add_time: bool = False, + ): + """Retrieve memories by recency. 
+ + Args: + k: number of entries to retrieve + add_time: whether to add time stamp to the output + + Returns: + List of strings corresponding to memories + """ + data = self._get_k_recent(k) + + return self._pd_to_text(data, add_time=add_time, sort_by_time=True) + + def retrieve_recent_with_importance( + self, + k: int = 1, + add_time: bool = False, + ): + """Retrieve memories by recency and return importance alongside. + + Args: + k: number of entries to retrieve + add_time: whether to add time stamp to the output + + Returns: + List of strings corresponding to memories + """ + data = self._get_k_recent(k) + + return ( + self._pd_to_text(data, add_time=add_time, sort_by_time=True), + list(data['importance']), + ) diff --git a/concordia/associative_memory/blank_memories.py b/concordia/associative_memory/blank_memories.py new file mode 100644 index 00000000..b8099710 --- /dev/null +++ b/concordia/associative_memory/blank_memories.py @@ -0,0 +1,56 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +"""This is a factory for generating memories for generative agents.""" + +from collections.abc import Callable +import datetime + +from concordia.associative_memory import associative_memory +from concordia.language_model import language_model +import numpy as np + + +class MemoryFactory: + """Generator of formative memories.""" + + def __init__( + self, + model: language_model.LanguageModel, + embedder: Callable[[str], np.ndarray], + importance: Callable[[str], float], + clock_now: Callable[[], datetime.datetime], + ): + self._model = model + self._embedder = embedder + self._importance = importance + self._clock_now = clock_now + + def make_blank_memory( + self, + ) -> associative_memory.AssociativeMemory: + """Creates a blank memory. + + Returns a blank memory + + Returns: + An empty memory structure + """ + + return associative_memory.AssociativeMemory( + self._embedder, + self._importance, + clock=self._clock_now, + ) diff --git a/concordia/associative_memory/embedder_st5.py b/concordia/associative_memory/embedder_st5.py new file mode 100644 index 00000000..35c81444 --- /dev/null +++ b/concordia/associative_memory/embedder_st5.py @@ -0,0 +1,41 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""English sentence embedding class using ST5. + +Ni, J., Ábrego, G.H., Constant, N., Ma, J., Hall, K.B., Cer, D. and Yang, Y., +2021. Sentence-t5: Scalable sentence encoders from pre-trained text-to-text +models. 
arXiv preprint arXiv:2108.08877. +""" + +from collections.abc import Callable + +import numpy as np +import tensorflow as tf +import tensorflow_hub as hub + +DEFAULT_ENCODER_URL = "https://tfhub.dev/google/sentence-t5/st5-base/1" + + +class EmbedderST5(Callable): + """Embeds text using ST5.""" + + def __init__(self, hub_url=DEFAULT_ENCODER_URL): + self._encoder = hub.KerasLayer(hub_url) + + def __call__(self, text: str) -> np.ndarray: + english_sentences = tf.constant([text]) + (batched_embedding,) = self._encoder(english_sentences) + return np.squeeze(batched_embedding.numpy()) diff --git a/concordia/associative_memory/formative_memories.py b/concordia/associative_memory/formative_memories.py new file mode 100644 index 00000000..a74bcf5e --- /dev/null +++ b/concordia/associative_memory/formative_memories.py @@ -0,0 +1,208 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +"""This is a factory for generating memories for concordia agents.""" + +from collections.abc import Callable, Iterable, Sequence +import dataclasses +import datetime +import re +from typing import Any +from concordia.associative_memory import associative_memory +from concordia.associative_memory import importance_function +from concordia.document import interactive_document +from concordia.language_model import language_model +from dateutil.relativedelta import relativedelta # pylint: disable=g-importing-member + + +DEFAULT_DOB = datetime.datetime(year=1984, month=7, day=3, hour=0, minute=0) +DEFAULT_FORMATIVE_AGES = (3, 7, 12, 16, 21) +DEFAULT_IMPORTANT_MODEL = importance_function.ConstantImportanceModel() + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class AgentConfig: + """A card that describes a player. + + Attributes: + name: name of the agent. + gender: the gender of the agent. + traits: any traits to use while generating formative memories. For example, + big five. + context: agent formative memories will be generated with this context + goal: defines agents goal. Can be left blank if not used. + date_of_birth: the date of birth for the agent. + formative_ages: ages at which the formative episodes will be created + formative_memory_importance: the importance value of formative memories. 
+ extras: a field for the user to keep any experiment specific data they need + to define an agent + """ + + name: str + gender: str + traits: str + context: str = '' + goal: str = '' + date_of_birth: datetime.datetime = DEFAULT_DOB + formative_ages: Iterable[int] = DEFAULT_FORMATIVE_AGES + formative_memory_importance: float = 1.0 + extras: dict[str, Any] = dataclasses.field(default_factory=dict) + + +class FormativeMemoryFactory: + """Generator of formative memories.""" + + def __init__( + self, + *, + model: language_model.LanguageModel, + shared_memories: Sequence[str] = (), + blank_memory_factory_call: Callable[ + [], associative_memory.AssociativeMemory + ], + ): + self._model = model + self._blank_memory_factory_call = blank_memory_factory_call + self._shared_memories = shared_memories + + def make_backstory( + self, name: str, gender: str, traits_description: str, context: str | None + ) -> str: + """Creates a backstory of an agent based on traits. + + Args: + name: name of the agent + gender: gender of the agent + traits_description: descriptive traits of an agent, for example big five + context: any context to add to the generation, i.e. genre + + Returns: + Descriptive text about the agent + """ + prompt = interactive_document.InteractiveDocument(self._model) + + if context: + prompt.statement(context) + question = ( + f'Given the following trats:\n{str(traits_description)}' + f'\n create a backstory about a {gender} character called {name}.' + ' Write a summary of the person:' + ' what their job is, what a typical day is is like, what are their' + ' goals, desires, hopes, dreams, and aspirations. Also write about' + ' their duties, responsibilities, and obligations. What gives them joy' + ' and what are they afraid of. Write about their friends and what they' + ' like to do. Also write about their current concerns.' 
+ ) + if context: + question += f'Take into account the following context: {context}' + result = prompt.open_question( + question, + max_characters=8000, + max_tokens=8000, + terminators=[], + ) + result = re.sub(r'\.\s', '.\n', result) + + query = '\n'.join([ + ( + 'Replace all the pronounce in the following text with the name' + f' {name}.' + ), + 'The text:', + result, + ]) + + description = self._model.sample_text(query) + description = re.sub(r'\.\s', '.\n', description) + + return description + + def make_memories( + self, + agent_config: AgentConfig, + ) -> associative_memory.AssociativeMemory: + """Creates agent memory from the agent card.""" + + mem = self._blank_memory_factory_call() + # All players share generic memories. + for item in self._shared_memories: + mem.add(item) + + context = agent_config.context + if agent_config.goal: + context += '\n' + agent_config.goal + + self.add_memories(memory=mem, agent_config=agent_config) + + if context: + context_items = context.split('\n') + for item in context_items: + if item: + mem.add(item, importance=1.0) + + return mem + + def add_memories( + self, + memory: associative_memory.AssociativeMemory, + agent_config: AgentConfig, + ) -> None: + """Creates formative memories of the agent at specific ages based on traits. + + First, a series of descriptive statements will be generated and based on + them the formative episodes. There is an option to add description to memory + as well. 
+ Args: + memory: the memory structure to fill + agent_config: the card describing the agent properties + """ + description = self.make_backstory( + agent_config.name, + agent_config.gender, + agent_config.traits, + agent_config.context, + ) + prompt = interactive_document.InteractiveDocument(self._model) + prompt.statement('Context: ' + description) + + for episode_age in agent_config.formative_ages: + question = ( + 'Given the context above, come up with a formative episode at the ' + + f'age of {episode_age}, which is consistent with' + f" {agent_config.name}'s " + + f"personality. Describe the episode from {agent_config.name}'s" + ' perspective ' + + 'using third-person limited point of view. Mention their age at ' + + 'the time. Use past tense. Write no more than three sentences.' + ) + if agent_config.context: + question += ( + '\nThe generated episode should be specifically related to some' + f' aspect of the following context: "{agent_config.context}"' + ) + + episode = prompt.open_question( + question, + max_characters=8000, + max_tokens=8000, + terminators=[], + ) + memory.add( + episode, + tags=['episode'], + timestamp=agent_config.date_of_birth + + relativedelta(years=episode_age), + importance=agent_config.formative_memory_importance, + ) diff --git a/concordia/associative_memory/importance_function.py b/concordia/associative_memory/importance_function.py new file mode 100644 index 00000000..8b8e3f0f --- /dev/null +++ b/concordia/associative_memory/importance_function.py @@ -0,0 +1,163 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Memory importance function.""" + +import abc +from collections.abc import Sequence + +from concordia.document import interactive_document +from concordia.language_model import language_model + + +DEFAULT_IMPORTANCE_SCALE = tuple(range(4)) + + +class ImportanceModel(metaclass=abc.ABCMeta): + """Memory importance module for generative agents.""" + + @abc.abstractmethod + def importance(self, memory: str) -> float: + """Computes importance of a memory. + + Args: + memory: a memory (text) to compute importance of + + Returns: + Value of importance in the [0,1] interval + """ + + raise NotImplementedError + + +class AgentImportanceModel(ImportanceModel): + """Memory importance function for simulacra agents. + + Importance is defined as poignancy of the memory according to LLM. + """ + + def __init__( + self, + model: language_model.LanguageModel, + importance_scale: Sequence[float] = DEFAULT_IMPORTANCE_SCALE, + ): + """Initialises an instance. + + Args: + model: LLM + importance_scale: a scale of poignancy + """ + self._model = model + self._importance_scale = [str(i) for i in sorted(importance_scale)] + + def importance(self, memory: str) -> float: + """Computes importance of a memory by quering LLM. 
+ + Args: + memory: memory to compute importance of + + Returns: + Value of importance in the [0,1] interval + """ + zero, *_, one = self._importance_scale + prompt = interactive_document.InteractiveDocument(self._model) + action = prompt.multiple_choice_question( + f"On the scale of {zero} to" + f" {one}, where {zero} is" + " purely mundane (e.g., brushing teeth, making bed) and" + f" {one} is extremely poignant (e.g., a break" + " up, college acceptance), rate the likely poignancy of the following" + " piece of memory.\nMemory:" + + memory + + "\nRating: ", + answers=self._importance_scale, + ) + return action / (len(self._importance_scale) - 1) + + +class GMImportanceModel(ImportanceModel): + """Memory importance function for a game master. + + Importance is defined as importance of the memory according to LLM. + """ + + def __init__( + self, + model: language_model.LanguageModel, + importance_scale: Sequence[float] = DEFAULT_IMPORTANCE_SCALE, + ): + """Initialises an instance. + + Args: + model: LLM + importance_scale: a scale of poignancy + """ + self._model = model + self._importance_scale = [str(i) for i in sorted(importance_scale)] + + def importance(self, memory: str) -> float: + """Computes importance of a memory by quering LLM. 
+ + Args: + memory: memory to compute importance of + + Returns: + Value of importance + """ + zero, *_, one = self._importance_scale + chain_of_thought = interactive_document.InteractiveDocument(self._model) + action = chain_of_thought.multiple_choice_question( + f"On the scale of {zero} to " + f"{one}, where {zero} is purely mundane " + f"(e.g., wind blowing, bus arriving) and {one} is " + "extremely poignant (e.g., an earthquake, end of war, " + "revolution), rate the likely poignancy of the " + "following event.\nEvent:" + + memory + + "\nRating: ", + answers=self._importance_scale, + ) + return action / (len(self._importance_scale) - 1) + + +class ConstantImportanceModel(ImportanceModel): + """Memory importance function that always returns a constant. + + This is useful for debugging since it doesn't call LLM. + """ + + def __init__( + self, + fixed_importance: float = 1.0, + ): + """Initialises an instance. + + Args: + fixed_importance: the constant to return + """ + self._fixed_importance = fixed_importance + + def importance(self, memory: str) -> float: + """Computes importance of a memory by quering LLM. + + Args: + memory: memory to compute importance of + + Returns: + Value of importance + """ + del memory + + return self._fixed_importance diff --git a/concordia/clocks/__init__.py b/concordia/clocks/__init__.py new file mode 100644 index 00000000..21637409 --- /dev/null +++ b/concordia/clocks/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + diff --git a/concordia/clocks/game_clock.py b/concordia/clocks/game_clock.py new file mode 100644 index 00000000..b2d28457 --- /dev/null +++ b/concordia/clocks/game_clock.py @@ -0,0 +1,163 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""A clock for synchronising simulacra.""" + +from collections.abc import Sequence +import contextlib +import datetime + +from concordia.typing import clock + +_DEFAULT_STEP_SIZE = datetime.timedelta(minutes=1) + + +class FixedIntervalClock(clock.GameClock): + """A fixed-interval clock for synchronising simulacra.""" + + def __init__( + self, + start: datetime.datetime | None = None, + step_size: datetime.timedelta = _DEFAULT_STEP_SIZE, + ): + """Initializes the clock. + + Args: + start: The start time of the clock. If None, the current time is used. + step_size: The step size of the clock. 
+ """ + if start is None: + self._start = datetime.datetime.now() + else: + self._start = start + self._step_size = step_size + self._step = 0 + + def advance(self): + """Advances time by step_size.""" + self._step += 1 + + def set(self, time: datetime.datetime): + self._step = (time - self._start) // self._step_size + + def now(self) -> datetime.datetime: + return self._start + self._step * self._step_size + + def get_step_size(self) -> datetime.timedelta: + return self._step_size + + def get_step(self) -> int: + return self._step + + def current_time_interval_str(self) -> str: + this_time = self.now() + next_time = this_time + self._step_size + + time_string = this_time.strftime( + ' %d %b %Y [%H:%M:%S - ' + ) + next_time.strftime('%H:%M:%S]') + return time_string + + +class MultiIntervalClock(clock.GameClock): + """A multi-interval clock for synchronising simulacra. + + This clock takes in multiple step sizes, which can be switched between using + gear_up and gear_down. Important: when advancing, the clock switches to the + next step in the current gear and zeros all steps of all higher gears. For + example if step sizes are 1 hour and 10 minutes and current time is 15:40, + then going back to lowest gear and advancing will yield 16:00 and not 16:40. + """ + + def __init__( + self, + start: datetime.datetime | None = None, + step_sizes: Sequence[datetime.timedelta] = (_DEFAULT_STEP_SIZE,), + ): + """Initializes the clock. + + Args: + start: The start time of the clock. If None, the current time is used. + step_sizes: The step sizes of the clock. + + Raises: + RuntimeError: If step_sizes are not sorted from lowest to highest. 
+ """ + if start is None: + self._start = datetime.datetime.now() + else: + self._start = start + + # the default makes it a fixed interval clock + if step_sizes != sorted(step_sizes, reverse=True): + raise RuntimeError('Step sizes have to be sorted from lowest to highest.') + + self._step_sizes = step_sizes + self._steps = [0] * len(step_sizes) + self._current_gear = 0 + + def gear_up(self) -> None: + if self._current_gear + 1 >= len(self._step_sizes): + raise RuntimeError('Already in highest gear.') + self._current_gear += 1 + + def gear_down(self) -> None: + if self._current_gear == 0: + raise RuntimeError('Already in lowest gear.') + self._current_gear -= 1 + + @contextlib.contextmanager + def higher_gear(self): + self.gear_up() + try: + yield + finally: + self.gear_down() + + def advance(self): + """Advances time by step_size.""" + self._steps[self._current_gear] += 1 + for gear in range(self._current_gear + 1, len(self._step_sizes)): + self._steps[gear] = 0 + self.set(self.now()) # resolve the higher gear running over the lower + + def set(self, time: datetime.datetime): + remainder = time - self._start + for gear, step_size in enumerate(self._step_sizes): + self._steps[gear] = remainder // step_size + remainder -= step_size * self._steps[gear] + + def now(self) -> datetime.datetime: + output = self._start + for gear, step_size in enumerate(self._step_sizes): + output += self._steps[gear] * step_size + return output + + def get_step_size(self) -> datetime.timedelta: + return self._step_sizes[self._current_gear] + + def get_step(self) -> int: + """Returns the current step in the lowest gear.""" + # this is used for logging, so makes sense to use lowest gear + return self._steps[0] + + def current_time_interval_str(self) -> str: + this_time = self.now() + next_time = this_time + self._step_sizes[self._current_gear] + + time_string = this_time.strftime( + ' %d %b %Y [%H:%M - ' + ) + next_time.strftime('%H:%M]') + return time_string diff --git 
a/concordia/clocks/game_clock_test.py b/concordia/clocks/game_clock_test.py new file mode 100644 index 00000000..65a85d9c --- /dev/null +++ b/concordia/clocks/game_clock_test.py @@ -0,0 +1,55 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import datetime +from absl.testing import absltest +from absl.testing import parameterized +from concordia.clocks import game_clock + + +class GameClockTest(parameterized.TestCase): + + def test_advance(self): + times = [] + clock = game_clock.MultiIntervalClock( + start=datetime.datetime(hour=8, year=2024, month=9, day=1), + step_sizes=[ + datetime.timedelta(hours=1), + datetime.timedelta(minutes=10), + ], + ) + clock.advance() + times.append(clock.now()) + + with clock.higher_gear(): + for _ in range(7): + clock.advance() + times.append(clock.now()) + times.append(clock.now()) + + clock.advance() + times.append(clock.now()) + + expected = [ + datetime.datetime(hour=9, year=2024, month=9, day=1), + datetime.datetime(minute=10, hour=10, year=2024, month=9, day=1), + datetime.datetime(minute=10, hour=10, year=2024, month=9, day=1), + datetime.datetime(minute=0, hour=11, year=2024, month=9, day=1), + ] + self.assertEqual(times, expected) + + +if __name__ == '__main__': + absltest.main() diff --git a/concordia/document/__init__.py b/concordia/document/__init__.py new file mode 100644 index 00000000..21637409 --- /dev/null +++ b/concordia/document/__init__.py @@ -0,0 +1,15 @@ +# 
Copyright 2023 DeepMind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
diff --git a/concordia/document/document.py b/concordia/document/document.py
new file mode 100644
index 00000000..4af34f8c
--- /dev/null
+++ b/concordia/document/document.py
@@ -0,0 +1,182 @@
+# Copyright 2023 DeepMind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""A document that is built from a chain of text."""
+
+from collections.abc import Collection, Iterable, Iterator, Set
+import contextlib
+import dataclasses
+from typing import TypeVar
+
+T = TypeVar('T')
+
+
+@dataclasses.dataclass(frozen=True)
+class Content:
+  """Content appended to a document.
+
+  Attributes:
+    text: the text of the content.
+    tags: tags provided at the time this was written to the document.
+  """
+  text: str
+  _: dataclasses.KW_ONLY
+  tags: Set[str] = frozenset()
+
+  # TODO: b/311191278 - implement _repr_pretty_, _repr_html_, _repr_markdown_
+
+  def __post_init__(self):
+    # The dataclass is frozen, so bypass the immutability guard to normalize
+    # whatever iterable of tags the caller passed into a hashable frozenset.
+    object.__setattr__(self, 'tags', frozenset(self.tags))
+
+  def __str__(self):
+    return self.text
+
+
+class Document:
+  """A document of text and associated metadata."""
+
+  def __init__(self, contents: Iterable[Content] = ()) -> None:
+    """Initializes the document.
+
+    Args:
+      contents: Initial contents of the document.
+    """
+    # TODO: b/311191572 - be more efficient if contents is a tuple iter.
+    self._contents = tuple(contents)
+
+  # TODO: b/311191905 - implement __iadd__, __add__?
+  # TODO: b/311191278 - implement _repr_pretty_, _repr_html_, _repr_markdown_
+
+  def __iter__(self) -> Iterator[Content]:
+    """Yields the contents in the document."""
+    yield from self._contents
+
+  def __eq__(self, other):
+    """Returns True if other is a Document with identical contents."""
+    if not isinstance(other, type(self)):
+      return NotImplemented
+    else:
+      return self._contents == other._contents
+
+  def __ne__(self, other):
+    """Returns True if other is not a Document or has different contents."""
+    return not self.__eq__(other)
+
+  def contents(self) -> tuple[Content, ...]:
+    """Returns the contents in the document."""
+    return self._contents
+
+  def text(self) -> str:
+    """Returns all the text in the document."""
+    return ''.join(content.text for content in self)
+
+  def view(
+      self,
+      include_tags: Iterable[str] = (),
+      exclude_tags: Iterable[str] = (),
+  ) -> 'View':
+    """Returns a live, tag-filtered view of the document.
+
+    Args:
+      include_tags: specifies which tags to include in the view.
+      exclude_tags: specifies which tags to exclude from the view.
+    """
+    return View(self, include_tags=include_tags, exclude_tags=exclude_tags)
+
+  def clear(self):
+    """Clears the document."""
+    self._contents = ()
+
+  def append(
+      self,
+      text: str,
+      *,
+      tags: Collection[str] = (),
+  ) -> None:
+    """Appends text to the document."""
+    text = Content(text=text, tags=frozenset(tags))
+    self._contents += (text,)
+
+  def extend(self, contents: Iterable[Content]) -> None:
+    """Extends the document with the provided contents."""
+    self._contents += tuple(contents)
+
+  def copy(self) -> 'Document':
+    """Returns a copy of the document."""
+    return Document(self.contents())
+
+  def new(self: T) -> T:
+    """Returns an empty copy of this document."""
+    document = self.copy()
+    document.clear()
+    return document
+
+  @contextlib.contextmanager
+  def edit(self: T) -> Iterator[T]:
+    """Edits the current document.
+
+    Creates an edit based on the current document. Once the context is
+    completed, the edit will be committed to the document. If you wish not to
+    commit the edit, call edit.clear() before leaving the context.
+
+    Yields:
+      The document being edited.
+    """
+    edit = self.new()
+    yield edit
+    self.extend(edit.contents())
+
+
+class View:
+  """A live, optionally tag-filtered view of a document."""
+
+  def __init__(
+      self,
+      document: Document,
+      include_tags: Iterable[str] = (),
+      exclude_tags: Iterable[str] = (),
+  ) -> None:
+    """Initializes the instance.
+
+    Args:
+      document: the base document on which to add edits.
+      include_tags: specifies which tags to include in the view.
+      exclude_tags: specifies which tags to exclude from the view.
+    """
+    self._include_tags = frozenset(include_tags)
+    self._exclude_tags = frozenset(exclude_tags)
+    common_tags = self._include_tags & self._exclude_tags
+    if common_tags:
+      raise ValueError(f'Cannot both include and exclude tags {common_tags!r}')
+    self._document = document
+
+  def __iter__(self) -> Iterator[Content]:
+    """Yields the contents in the view."""
+    # Exclusion takes precedence over inclusion; an empty filter set means
+    # "no constraint" on that side.
+    for content in self._document:
+      if self._exclude_tags and content.tags & self._exclude_tags:
+        continue
+      elif self._include_tags and not content.tags & self._include_tags:
+        continue
+      else:
+        yield content
+
+  def contents(self) -> tuple[Content, ...]:
+    """Returns the contents in the view."""
+    return tuple(self)
+
+  def text(self) -> str:
+    """Returns the text of the view as a single string."""
+    return ''.join(content.text for content in self)
diff --git a/concordia/document/document_test.py b/concordia/document/document_test.py
new file mode 100644
index 00000000..4062da43
--- /dev/null
+++ b/concordia/document/document_test.py
@@ -0,0 +1,179 @@
+# Copyright 2023 DeepMind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ + +from absl.testing import absltest +from absl.testing import parameterized +from concordia.document import document + + +class DocumentTest(parameterized.TestCase): + + def test_init(self): + doc = document.Document() + with self.subTest('text'): + self.assertEmpty(doc.text()) + with self.subTest('contents'): + self.assertEmpty(doc.contents()) + + def test_append(self): + doc = document.Document() + doc.append('one', tags=['a', 'b']) + doc.append('two', tags=['b', 'c']) + + with self.subTest('text'): + self.assertEqual(doc.text(), 'onetwo') + + with self.subTest('contents'): + expected = [ + document.Content(text='one', tags=frozenset({'a', 'b'})), + document.Content(text='two', tags=frozenset({'b', 'c'})), + ] + self.assertSequenceEqual(doc.contents(), expected) + + with self.subTest('document'): + self.assertNotEmpty(doc.text()) + + def test_clear(self): + doc = document.Document() + doc.append('one', tags=['a', 'b']) + doc.append('two', tags=['b', 'c']) + doc.clear() + + with self.subTest('text'): + self.assertEmpty(doc.text()) + with self.subTest('contents'): + self.assertEmpty(doc.contents()) + with self.subTest('document'): + self.assertEmpty(doc.text()) + + def test_view(self): + doc = document.Document() + view = doc.view() + initial_text = view.text() + initial_contents = view.contents() + doc.append('one', tags=['a', 'b']) + doc.append('two', tags=['b', 'c']) + doc.append('three', tags=['c', 'd']) + final_text = view.text() + final_contents = view.contents() + + with self.subTest('initial_text'): + self.assertEqual(initial_text, '') + with self.subTest('initial_contents'): + self.assertEmpty(initial_contents) + with self.subTest('final_text'): + self.assertEqual(final_text, 'onetwothree') + with self.subTest('final_contents'): + self.assertSequenceEqual( + final_contents, + [ + document.Content(text='one', tags=frozenset({'a', 'b'})), + document.Content(text='two', tags=frozenset({'b', 'c'})), + document.Content(text='three', tags=frozenset({'c', 
'd'})), + ], + ) + + def test_filtered_view(self): + doc = document.Document() + view = doc.view(include_tags={'b'}, exclude_tags={'a'}) + initial_text = view.text() + initial_contents = view.contents() + doc.append('one', tags=['a', 'b']) + doc.append('two', tags=['b', 'c']) + doc.append('three', tags=['c', 'd']) + final_text = view.text() + final_contents = view.contents() + + with self.subTest('initial_text'): + self.assertEmpty(initial_text) + with self.subTest('initial_contents'): + self.assertEmpty(initial_contents) + with self.subTest('final_text'): + self.assertEqual(final_text, 'two') + with self.subTest('final_contents'): + self.assertSequenceEqual( + final_contents, + [ + document.Content(text='two', tags=frozenset({'b', 'c'})), + ], + ) + + def test_edit(self): + doc = document.Document() + doc.append('one', tags=['a', 'b']) + doc.append('two', tags=['b', 'c']) + doc.append('three', tags=['c', 'd']) + doc_before_edit = doc.contents() + with doc.edit() as edit: + edit.append('four', tags=['d', 'e']) + doc_during_edit = doc.contents() + edit_contents = edit.contents() + doc_after_edit = doc.contents() + + with self.subTest('doc_during_edit'): + self.assertEqual(doc_during_edit, doc_before_edit) + with self.subTest('doc_after_edit'): + self.assertEqual(doc_after_edit, doc_before_edit + edit_contents) + with self.subTest('edit_contents'): + self.assertSequenceEqual( + edit_contents, + [ + document.Content(text='four', tags=frozenset({'d', 'e'})), + ], + ) + + def test_edit_rollback(self): + doc = document.Document() + doc.append('one', tags=['a', 'b']) + doc.append('two', tags=['b', 'c']) + doc.append('three', tags=['c', 'd']) + doc_before_edit = doc.contents() + with doc.edit() as edit: + edit.append('four', tags=['d', 'e']) + edit.clear() + doc_after_edit = doc.contents() + + with self.subTest('doc_after_no_edit'): + self.assertEqual(doc_after_edit, doc_before_edit) + with self.subTest('empty_edit'): + self.assertEmpty(edit.contents()) + + def 
test_eq(self): + doc = document.Document() + doc.append('one', tags=['a', 'b']) + doc.append('two', tags=['b', 'c']) + doc.append('three', tags=['c', 'd']) + copy = doc.copy() + self.assertEqual(doc, copy) + + def test_ne(self): + doc = document.Document() + doc.append('one', tags=['a', 'b']) + doc.append('two', tags=['b', 'c']) + copy = doc.copy() + doc.append('three', tags=['c', 'd']) + self.assertNotEqual(doc, copy) + + def test_new(self): + doc = document.Document() + doc.append('one', tags=['a', 'b']) + doc.append('two', tags=['b', 'c']) + doc.append('three', tags=['c', 'd']) + new_doc = doc.new() + self.assertEqual(new_doc, document.Document()) + + +if __name__ == '__main__': + absltest.main() diff --git a/concordia/document/interactive_document.py b/concordia/document/interactive_document.py new file mode 100644 index 00000000..5ccf0474 --- /dev/null +++ b/concordia/document/interactive_document.py @@ -0,0 +1,216 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +"""Utilities for chain-of-thought prompting.""" + +from collections.abc import Collection, Iterable, Iterator, Sequence +import contextlib + +from concordia.document import document +from concordia.language_model import language_model +import numpy as np + +DEFAULT_MAX_CHARACTERS = 200 +DEFAULT_MAX_TOKENS = DEFAULT_MAX_CHARACTERS // 4 + +DEBUG_TAG = 'debug' +STATEMENT_TAG = 'statement' +QUESTION_TAG = 'question' +RESPONSE_TAG = 'response' +MODEL_TAG = 'model' +INTERACTIVE_TAGS = frozenset( + {DEBUG_TAG, STATEMENT_TAG, QUESTION_TAG, RESPONSE_TAG, MODEL_TAG} +) + + +_YESNO = ['No', 'Yes'] + + +def _letters(): + """Yields the letters from a to z.""" + yield from (chr(ord('a') + i) for i in range(26)) + + +class InteractiveDocument(document.Document): + """A document formed by interaction with a language model.""" + + def __init__( + self, + model: language_model.LanguageModel, + contents: Iterable[document.Content] = (), + rng: np.random.Generator | None = None, + ) -> None: + """Initializes the instance. + + Args: + model: language model to interact with. + contents: initial contents of the document. + rng: randomization source. + """ + super().__init__(contents) + if rng: + self._rng = rng + else: + self._rng = np.random.default_rng() + self._model = model + self._model_view = self.view() + # TODO: b/311191701 - debug log some useful stuff? + + def view( + self, + include_tags: Iterable[str] = (), + exclude_tags: Iterable[str] = (DEBUG_TAG,), + ) -> document.View: + """Returns a view of the document. + + Args: + include_tags: specifies which tags to include in the view. + exclude_tags: specifies which tags to exclude from the view. + """ + return super().view(include_tags=include_tags, exclude_tags=exclude_tags) + + def copy(self) -> 'InteractiveDocument': + """See base class.""" + # TODO: b/311192069 - what about rng? 
+ return InteractiveDocument( + model=self._model, contents=self.contents(), rng=self._rng + ) + + @contextlib.contextmanager + def edit(self) -> Iterator['InteractiveDocument']: + """See base class.""" + # TODO: b/311192069 - what about rng? + edit = InteractiveDocument(model=self._model, rng=self._rng) + yield edit + self.extend(edit.contents()) + + def debug( + self, text: str, *, tags: Collection[str] = (), end: str = '\n' + ) -> None: + """Appends debug text to the document. + + Args: + text: text to append. + tags: additional tags for appended text. + end: appended to `text`. + """ + self.append(text + end, tags=[DEBUG_TAG, *tags]) + + def statement( + self, text: str, *, tags: Collection[str] = (), end: str = '\n' + ) -> None: + """Appends a statement to the document. + + Args: + text: text to append. + tags: additional tags for appended text. + end: appended to `text`. + """ + self.append(text + end, tags=[STATEMENT_TAG, *tags]) + + def _question( + self, text: str, *, tags: Collection[str] = (), end: str = '' + ) -> None: + """Appends a question to the document.""" + self.append(text + end, tags=[QUESTION_TAG, *tags]) + + def _response( + self, text: str, *, tags: Collection[str] = (), end: str = '' + ) -> None: + """Appends a response to the document.""" + self.append(text + end, tags=[RESPONSE_TAG, *tags]) + + def _model_response( + self, text: str, *, tags: Collection[str] = (), end: str = '' + ) -> None: + """Appends a response to the document that was generated by the model.""" + self.append(text + end, tags=[RESPONSE_TAG, MODEL_TAG, *tags]) + + def open_question( + self, + question: str, + *, + answer_prefix: str = '', + answer_suffix: str = '', + max_tokens: int = DEFAULT_MAX_TOKENS, + max_characters: int = DEFAULT_MAX_CHARACTERS, + terminators: Collection[str] = ('\n',), + ) -> str: + """Asks the agent an open question and appends it to the document. + + Args: + question: the question to ask. 
+ answer_prefix: a prefix to append to the model's prompt. + answer_suffix: a suffix to append to the model's response. + max_tokens: the maximum number of tokens to sample from the model. + max_characters: the maximum number of characters to sample from the model. + terminators: strings that must not be present in the model's response. If + emitted by the model the response will be truncated before them. + + Returns: + The agents truncated response. + """ + self._question(f'Question: {question}\n') + self._response(f'Answer: {answer_prefix}') + response = self._model.sample_text( + prompt=self._model_view.text(), + max_tokens=max_tokens, + max_characters=max_characters, + terminators=terminators, + ) + self._model_response(response) + self._response(f'{answer_suffix}\n') + return response + + def multiple_choice_question( + self, question: str, answers: Sequence[str] + ) -> int: + """Presents a multiple choice to the agent. + + Args: + question: the question to ask the agent. + answers: the choice of answers + + Returns: + The index of the sampled answer. + """ + original_indices = self._rng.permutation(len(answers)) + options = {key: answers[i] for key, i in zip(_letters(), original_indices)} + self._question(f'Question: {question}\n') + for key, option in options.items(): + self._question(f' ({key}) {option}\n') + + self._response('Answer: (') + idx, response, debug = self._model.sample_choice( + prompt=self._model_view.text(), + responses=list(options.keys()), + ) + self._model_response(response) + self._response(')\n') + self.debug(f'[{debug}]') + return original_indices[idx] + + def yes_no_question(self, question: str) -> bool: + """Presents a yes/no question to the agent. + + Args: + question: the question to ask the agent. + + Returns: + True iff the answer was answered with Yes. 
+ """ + return self.multiple_choice_question(question, _YESNO) == _YESNO.index( + 'Yes' + ) diff --git a/concordia/document/interactive_document_test.py b/concordia/document/interactive_document_test.py new file mode 100644 index 00000000..c38d2544 --- /dev/null +++ b/concordia/document/interactive_document_test.py @@ -0,0 +1,202 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import functools +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +from concordia.document import document +from concordia.document import interactive_document +from concordia.language_model import language_model +import numpy as np + + +DEBUG = functools.partial(document.Content, tags=frozenset({'debug'})) +STATEMENT = functools.partial(document.Content, tags=frozenset({'statement'})) +QUESTION = functools.partial(document.Content, tags=frozenset({'question'})) +RESPONSE = functools.partial(document.Content, tags=frozenset({'response'})) +MODEL_RESPONSE = functools.partial( + document.Content, tags=frozenset({'response', 'model'}) +) + + +class InteractiveDocumentTest(parameterized.TestCase): + + def test_open_question(self): + model = mock.create_autospec( + language_model.LanguageModel, instance=True, spec_set=True + ) + model.sample_text.return_value = 'This is a long answer' + + doc = interactive_document.InteractiveDocument(model) + doc.statement('Hello') + response = 
doc.open_question( + question='What is 1+1?', + answer_prefix='Well...', + max_tokens=mock.sentinel.max_tokens, + max_characters=mock.sentinel.max_characters, + terminators=mock.sentinel.terminators, + ) + + with self.subTest('response'): + self.assertEqual(response, 'This is a long answer') + + with self.subTest('model'): + prompt = """Hello +Question: What is 1+1? +Answer: Well...""" + model.sample_text.assert_called_once_with( + prompt=prompt, + max_tokens=mock.sentinel.max_tokens, + max_characters=mock.sentinel.max_characters, + terminators=mock.sentinel.terminators, + ) + + with self.subTest('text'): + expected = """Hello +Question: What is 1+1? +Answer: Well...This is a long answer +""" + self.assertEqual(doc.text(), expected) + + with self.subTest('contents'): + expected = ( + STATEMENT('Hello\n'), + QUESTION('Question: What is 1+1?\n'), + RESPONSE('Answer: Well...'), + MODEL_RESPONSE('This is a long answer'), + RESPONSE('\n'), + ) + self.assertEqual(doc.contents(), expected) + + def test_multiple_choice_question(self): + model = mock.create_autospec( + language_model.LanguageModel, instance=True, spec_set=True + ) + model.sample_choice.return_value = (2, 'c', mock.sentinel.debug) + rng = mock.create_autospec( + np.random.Generator, instance=True, spec_set=True + ) + rng.permutation.return_value = np.arange(3)[::-1] + + doc = interactive_document.InteractiveDocument(model, rng=rng) + doc.statement('Hello') + response = doc.multiple_choice_question( + question='What is 1+1?', + answers=['1', '2', '3'], + ) + + with self.subTest('response'): + self.assertEqual(response, 0) + + with self.subTest('model'): + prompt = """Hello +Question: What is 1+1? + (a) 3 + (b) 2 + (c) 1 +Answer: (""" + model.sample_choice.assert_called_once_with( + prompt=prompt, responses=['a', 'b', 'c'] + ) + + with self.subTest('text'): + expected = """Hello +Question: What is 1+1? 
+ (a) 3 + (b) 2 + (c) 1 +Answer: (c) +[sentinel.debug] +""" + self.assertEqual(doc.text(), expected) + + with self.subTest('contents'): + expected = ( + STATEMENT('Hello\n'), + QUESTION('Question: What is 1+1?\n'), + QUESTION(' (a) 3\n'), + QUESTION(' (b) 2\n'), + QUESTION(' (c) 1\n'), + RESPONSE('Answer: ('), + MODEL_RESPONSE('c'), + RESPONSE(')\n'), + DEBUG('[sentinel.debug]\n'), + ) + self.assertSequenceEqual(doc.contents(), expected) + + def test_debug_hidden_from_default_view(self): + model = mock.create_autospec( + language_model.LanguageModel, instance=True, spec_set=True + ) + model.sample_choice.return_value = (2, 'c', mock.sentinel.debug) + rng = mock.create_autospec( + np.random.Generator, instance=True, spec_set=True + ) + rng.permutation.return_value = np.arange(3)[::-1] + + doc = interactive_document.InteractiveDocument(model, rng=rng) + doc.statement('Hello') + doc.multiple_choice_question( + question='What is 1+1?', + answers=['1', '2', '3'], + ) + + with self.subTest('view'): + expected = """Hello +Question: What is 1+1? 
+ (a) 3 + (b) 2 + (c) 1 +Answer: (c) +""" + self.assertEqual(doc.view().text(), expected) + + def test_yes_no_question_answer_yes(self): + model = mock.create_autospec( + language_model.LanguageModel, instance=True, spec_set=True + ) + + rng = mock.create_autospec( + np.random.Generator, instance=True, spec_set=True + ) + rng.permutation.return_value = [0, 1] + + doc = interactive_document.InteractiveDocument(model, rng=rng) + doc.statement('Hello') + model.sample_choice.return_value = (1, 'b', mock.sentinel.debug) + response = doc.yes_no_question(question='Does 1+1 equal 2?') + self.assertTrue(response) + + def test_yes_no_question_answer_no(self): + model = mock.create_autospec( + language_model.LanguageModel, instance=True, spec_set=True + ) + + rng = mock.create_autospec( + np.random.Generator, instance=True, spec_set=True + ) + rng.permutation.return_value = [0, 1] + + doc = interactive_document.InteractiveDocument(model, rng=rng) + doc.statement('Hello') + model.sample_choice.return_value = (0, 'a', mock.sentinel.debug) + response = doc.yes_no_question(question='Does 1+1 equal 3?') + self.assertFalse(response) + + +if __name__ == '__main__': + absltest.main() diff --git a/concordia/environment/__init__.py b/concordia/environment/__init__.py new file mode 100644 index 00000000..21637409 --- /dev/null +++ b/concordia/environment/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + diff --git a/concordia/environment/components/__init__.py b/concordia/environment/components/__init__.py new file mode 100644 index 00000000..130c16db --- /dev/null +++ b/concordia/environment/components/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2022 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Library of components for generative game master and agents.""" + +from concordia.environment.components import conversation +from concordia.environment.components import direct_effect +from concordia.environment.components import inventory +from concordia.environment.components import player_status +from concordia.environment.components import relevant_events +from concordia.environment.components import schedule +from concordia.environment.components import time_display diff --git a/concordia/environment/components/conversation.py b/concordia/environment/components/conversation.py new file mode 100644 index 00000000..e175a819 --- /dev/null +++ b/concordia/environment/components/conversation.py @@ -0,0 +1,322 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Externality for the Game Master, which generates conversations.""" + +from collections.abc import Sequence + +from concordia.agents import basic_agent +from concordia.agents import components as sim_components +from concordia.associative_memory import associative_memory +from concordia.associative_memory import blank_memories +from concordia.clocks import game_clock +from concordia.document import interactive_document +from concordia.environment.scenes import conversation as conversation_scene +from concordia.language_model import language_model +from concordia.typing import clock as clock_lib +from concordia.typing import component +from concordia.typing import metric +from concordia.utils import helper_functions +import termcolor + + +class Conversation(component.Component): + """Conversation generator.""" + + def __init__( + self, + players: Sequence[basic_agent.BasicAgent], + model: language_model.LanguageModel, + memory: associative_memory.AssociativeMemory, + clock: game_clock.MultiIntervalClock, + burner_memory_factory: blank_memories.MemoryFactory, + cap_nonplayer_characters: int = 3, + game_master_instructions: str = '', + shared_context: str = '', + measurements: Sequence[metric.Metric] | None = None, + components: Sequence[component.Component] | None = None, + allow_self_talk: bool = False, + verbose: bool = False, + print_colour: str = 'magenta', + ): + """Initializes the generator of conversations. + + Args: + players: A list of players to generate conversations for. + model: A language model to use for generating utterances. 
+ memory: GM memory, used to add the summary of the conversation + clock: multi intercal game clock. + burner_memory_factory: a memory factory to create temporary memory for + npcs and conversation gm + cap_nonplayer_characters: The maximum number of non-player characters + allowed in the conversation. + game_master_instructions: A string to use as the game master instructions. + shared_context: A string to use as the generic context for the NPCs. + measurements: metrics to pass into the conversation GM + components: components that contextualise the conversation + allow_self_talk: allow players to have a conversation with themselves + verbose: Whether to print debug messages or not. + print_colour: colour in which to print logs + """ + self._players = players + self._model = model + self._cap_nonplayer_characters = cap_nonplayer_characters + self._game_master_instructions = game_master_instructions + self._shared_context = shared_context + self._history = [] + self._verbose = verbose + self._print_colour = print_colour + self._components = components or [] + self._clock = clock + self._burner_memory_factory = burner_memory_factory + self._memory = memory + self._measurements = measurements + self._allow_self_talk = allow_self_talk + self._all_player_names = [player.name for player in self._players] + self._min_speakers = 1 if self._allow_self_talk else 2 + + def name(self) -> str: + return 'Conversations' + + def get_history(self): + return self._history.copy() + + def get_last_log(self): + if self._history: + return self._history[-1].copy() + + def get_player_names(self): + return [player.name for player in self._players] + + def _log(self, entry): + print(termcolor.colored(entry, self._print_colour)) + + def _make_npc( + self, name: str, scene_clock: clock_lib.GameClock + ) -> basic_agent.BasicAgent: + context = ( + f'{name} is a non-player character. 
Everyone knows the' + f' following:\n{self._shared_context}' + ) + + mem = self._burner_memory_factory.make_blank_memory() + + npc = basic_agent.BasicAgent( + model=self._model, + memory=mem, + agent_name=name, + clock=scene_clock, + components=[ + sim_components.constant.ConstantConstruct( + name='Instructions:', state=self._game_master_instructions + ), + sim_components.constant.ConstantConstruct( + name='General knowledge:', state=context + ), + sim_components.observation.Observation(agent_name=name, memory=mem), + ], + verbose=True, + ) + npc.update() + return npc + + def _get_nonplayer_characters( + self, + prompt: interactive_document.InteractiveDocument, + scene_clock: clock_lib.GameClock, + ) -> list[basic_agent.BasicAgent]: + prompt = prompt.copy() + nonplayer_characters = [] + npcs_exist = prompt.yes_no_question( + 'Are there any non-player characters in the conversation?' + ) + + if npcs_exist: + npcs = prompt.open_question( + 'Provide the list of non-player characters in the conversation ' + + 'as a comma-separated list. For example: "bartender, merchant" ' + + 'or "accountant, pharmacist, fishmonger". Non-player ' + + 'characters should be named only by generic characteristics ' + + 'such as their profession or role (e.g. shopkeeper).' 
+ ) + npc_names = helper_functions.extract_from_generated_comma_separated_list( + npcs + ) + if len(npc_names) > self._cap_nonplayer_characters: + npc_names = npc_names[: self._cap_nonplayer_characters] + + nonplayer_characters = [ + self._make_npc(name, scene_clock) for name in npc_names + ] + + return nonplayer_characters + + def _generate_convo_summary(self, convo: list[str]): + summary = self._model.sample_text( + '\n'.join( + convo + ['Summaries the conversation above in one sentence.'], + ), + max_characters=2000, + max_tokens=2000, + terminators=(), + ) + return summary + + def _who_talked( + self, + player_names_in_conversation: list[str], + nonplayers_in_conversation: list[basic_agent.BasicAgent], + ): + who_talked = ( + 'Summary of a conversation between ' + + ', '.join(player_names_in_conversation) + + '. ' + ) + if nonplayers_in_conversation: + who_talked = ( + who_talked + + 'Also present: ' + + ', '.join( + [ + npc_conversant.name + for npc_conversant in nonplayers_in_conversation + ] + ) + + '.' + ) + return who_talked + + def update_after_event( + self, + event_statement: str, + ) -> None: + """Potentially creates the conversation from an event statement. + + Args: + event_statement: A string describing the event. + + Returns: + A list of strings describing the conversation. + """ + document = interactive_document.InteractiveDocument(self._model) + player_names = self.get_player_names() + + for construct in self._components: + document.statement(construct.name() + ': ' + construct.state() + '\n') + + document.statement(f'Event: {event_statement}\n') + conversation_occurred = document.yes_no_question( + 'Does the event suggest anyone said anything or is about to speak?' 
+ ) + conversation_summary = '' + if self._verbose: + self._log('\n Checking if conversation occurred.') + + conversation_log = { + 'date': self._clock.now(), + 'Event statement': event_statement, + 'Summary': 'No conversation occurred.', + } + + # if yes, then propagate the event + if conversation_occurred: + player_names_in_conversation = [] + if self._verbose: + self._log('\n Conversation occurred. ') + document.statement('Conversation occurred.') + for player_name in player_names: + in_conversation = helper_functions.filter_copy_as_statement( + document + ).yes_no_question( + 'Does the event description explicitly state that' + f' {player_name} took part in the conversation?' + ) + if in_conversation: + player_names_in_conversation.append(player_name) + if self._verbose: + self._log( + '\n Players in conversation:' + + ', '.join(player_names_in_conversation) + + '.\n' + ) + if self._verbose: + self._log(document.view().text()) + + if player_names_in_conversation: + players_in_conversation = [ + player + for player in self._players + if player.name in player_names_in_conversation + ] + + nonplayers_in_conversation = self._get_nonplayer_characters( + document, self._clock + ) + + # this ensures that npcs can't duplicate players due to LLM mistake + nonplayers_in_conversation = [ + player + for player in nonplayers_in_conversation + if player.name not in self._all_player_names + ] + total_speakers = len(nonplayers_in_conversation) + len( + players_in_conversation + ) + + if total_speakers < self._min_speakers: + self._history.append(conversation_log) + return + + convo_scene = conversation_scene.make_conversation_game_master( + players_in_conversation + nonplayers_in_conversation, + clock=self._clock, + model=self._model, + memory_factory=self._burner_memory_factory, + name='Conversation scene', + premise=event_statement, + measurements=self._measurements, + ) + with self._clock.higher_gear(): + scene_output = convo_scene.run_episode() + conversation_summary = 
self._generate_convo_summary(scene_output) + + for player in players_in_conversation: + player.observe(conversation_summary) + + who_talked = self._who_talked( + player_names_in_conversation, nonplayers_in_conversation + ) + + conversation_log = { + 'date': self._clock.now(), + 'Who talked?': who_talked, + 'Event statement': event_statement, + 'Summary': conversation_summary, + 'Full conversation': scene_output, + 'Chain of thought': { + 'Summary': 'Conversation chain of thought', + 'Chain': document.view().text().splitlines(), + }, + 'Scene log': convo_scene.get_history(), + } + + conversation_summary = who_talked + ' ' + conversation_summary + + if self._verbose: + self._log(scene_output) + self._log(conversation_summary) + + self._history.append(conversation_log) + self._memory.add(conversation_summary) diff --git a/concordia/environment/components/direct_effect.py b/concordia/environment/components/direct_effect.py new file mode 100644 index 00000000..654e8e23 --- /dev/null +++ b/concordia/environment/components/direct_effect.py @@ -0,0 +1,156 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +"""Externality for the Game Master, which tracks direct effect on players.""" + +from collections.abc import Callable, Sequence +import concurrent.futures +import datetime + +from concordia.agents import basic_agent +from concordia.associative_memory import associative_memory +from concordia.document import interactive_document +from concordia.language_model import language_model +from concordia.typing import component +from concordia.utils import helper_functions +import termcolor + + +class DirectEffect(component.Component): + """Tracks direct effect on players. + + A direct effect is an event that directly affects a player in the list of + players. + """ + + def __init__( + self, + players: Sequence[basic_agent.BasicAgent], + clock_now: Callable[[], datetime.datetime], + model: language_model.LanguageModel, + memory: associative_memory.AssociativeMemory, + components: Sequence[component.Component] | None = None, + verbose: bool = False, + print_colour: str = 'magenta', + ): + self._players = players + self._verbose = verbose + self._print_colour = print_colour + self._components = components or [] + self._clock_now = clock_now + self._history = [] + self._model = model + self._memory = memory + + def name(self) -> str: + return 'Effect of event on players' + + def _print(self, entry: str): + print(termcolor.colored(entry, self._print_colour), end='') + + def get_player_names(self): + return [player.name for player in self._players] + + def get_history(self): + return self._history.copy() + + def get_last_log(self): + if self._history: + return self._history[-1].copy() + + def update_after_event( + self, + event_statement: str, + ) -> None: + document = interactive_document.InteractiveDocument(self._model) + + for construct in self._components: + document.statement(construct.name() + ': ' + construct.state() + '\n') + + player_names = self.get_player_names() + direct_effect_on_someone = document.yes_no_question( + 'Does the following event directly affect 
anyone from this ' + + f'list?\n List: {player_names}.\n Event: {event_statement}' + ) + effect_unknown = [] + effect_known = [] + + def _update_player(player): + player_name = player.name + player_doc = helper_functions.filter_copy_as_statement(document) + affected = player_doc.yes_no_question( + f'Does the event affect {player_name} status?' + ) + if affected: + if self._verbose: + self._print(f'\n{player_name} affected, might not known.') + known = player_doc.yes_no_question( + f'Does {player_name} know about the event?' + ) + if known: + if self._verbose: + self._print(f'\n{player_name} known.') + _ = player_doc.open_question( + f'What does {player_name} know about the event?' + ) + how_player_saw_event_first_person = player_doc.open_question( + f"Concisely summarize the event from {player_name}'s " + + 'perspective using third-person limited point of view.' + ) + player.observe(how_player_saw_event_first_person) + if self._verbose: + self._print( + f'\nEffect on {player_name}:' + f' {how_player_saw_event_first_person}' + ) + effect_known.append(how_player_saw_event_first_person) + else: # not known + if self._verbose: + self._print(f'\n{player_name} not known.') + effect_despite_ignorance = player_doc.open_question( + f'How does the event affect {player_name}`s status, despite them' + ' not knowing about it?' + ) + if self._verbose: + self._print( + f'\nUnknown effect on {player_name}: {effect_despite_ignorance}' + ) + effect = f'[effect on {player_name}] {effect_despite_ignorance}' + self._memory.add(effect) + effect_unknown.append(effect) + + # Determined whether externality has happened + # if yes, then propagate the event + if direct_effect_on_someone: + if self._verbose: + self._print( + '\nThe event had a direct affect on one of the players, resolving.' 
+ ) + + with concurrent.futures.ThreadPoolExecutor() as executor: + executor.map(_update_player, self._players) + + update_log = { + 'date': self._clock_now(), + 'Event statement': event_statement, + 'Summary': f'The effect of "{event_statement}"', + 'Known effect': effect_known, + 'Unknown effect': effect_unknown, + 'Chain of thought': { + 'Summary': 'Direct effect chain of thought', + 'Chain': document.view().text().splitlines(), + }, + } + self._history.append(update_log) diff --git a/concordia/environment/components/inventory.py b/concordia/environment/components/inventory.py new file mode 100644 index 00000000..33611a30 --- /dev/null +++ b/concordia/environment/components/inventory.py @@ -0,0 +1,268 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +"""A component to represent each agent's inventory or possessions.""" + +from collections.abc import Callable, Sequence +import concurrent +import dataclasses +import datetime + +from concordia.associative_memory import associative_memory +from concordia.document import interactive_document +from concordia.language_model import language_model +from concordia.typing import component +from concordia.utils import helper_functions +import numpy as np +import termcolor + + +_DEFAULT_QUANTITY = 0 + + +@dataclasses.dataclass(frozen=True) +class ItemTypeConfig: + """Class for configuring a type of item to track in an Inventory.""" + + name: str + minimum: float = -np.inf + maximum: float = np.inf + force_integer: bool = False + + +def _many_or_much(is_count_noun: bool) -> str: + """Return 'many' if input is True and 'much' if input is False.""" + if is_count_noun: + return 'many' + else: + return 'much' + + +class Inventory(component.Component): + """A grounded inventory tracking amounts of items in python.""" + + def __init__( + self, + model: language_model.LanguageModel, + memory: associative_memory.AssociativeMemory, + item_type_configs: Sequence[ItemTypeConfig], + player_initial_endowments: dict[str, dict[str, float]], + clock_now: Callable[[], datetime.datetime], + financial: bool = False, + name: str = 'Inventory', + verbose: bool = False, + ): + """Initialize a grounded inventory component tracking objects in python. + + Args: + model: a language model + memory: an associative memory + item_type_configs: sequence of item type configurations + player_initial_endowments: dict mapping player name to a dictionary with + item types as keys and initial endownments as values. + clock_now: Function to call to get current time. + financial: If set to True then include special questions to handle the + fact that agents typically say "Alice bought (or sold) X" which is + a different way of speaking than "Alice exchanged X for Y". + name: the name of this component e.g. 
Possessions, Account, Property, etc + verbose: whether to print the full update chain of thought or not + """ + self._model = model + self._memory = memory + self._player_initial_endowments = player_initial_endowments + self._financial = financial + self._clock_now = clock_now + self._name = name + self._verbose = verbose + + self._item_types = [config.name for config in item_type_configs] + self._item_types_dict = { + config.name: config for config in item_type_configs + } + self._player_names = list(player_initial_endowments.keys()) + + self._inventories = {} + for player_name, endowment in player_initial_endowments.items(): + self._inventories[player_name] = { + item_type: endowment.get(item_type, _DEFAULT_QUANTITY) + for item_type in self._item_types + } + + self._history = [] + self._state = '' + self._partial_states = {name: '' for name in self._player_names} + + # Determine if each item type is a count noun or a mass noun. + self._is_count_noun = {} + + def check_if_count_noun(item_type): + self._is_count_noun[item_type] = helper_functions.is_count_noun( + item_type, self._model + ) + return + + with concurrent.futures.ThreadPoolExecutor( + max_workers=len(self._item_types) + ) as executor: + executor.map(check_if_count_noun, self._item_types) + + # Set the initial state's string representation. 
+ self.update() + + def name(self) -> str: + """Returns the name of this component.""" + return self._name + + def get_last_log(self): + if self._history: + return self._history[-1].copy() + + def get_history(self): + return self._history.copy() + + def _get_player_inventory_str(self, player_name: str) -> str: + return f"{player_name}'s {self._name}: " + str( + self._inventories[player_name] + ) + + def state(self) -> str: + return self._state + + def partial_state( + self, + player_name: str, + ) -> str: + """Return a player-specific view of the component's state.""" + return self._partial_states[player_name] + + def update(self) -> None: + self._state = '\n'.join( + [self._get_player_inventory_str(name) for name in self._player_names] + ) + self._partial_states = { + name: self._get_player_inventory_str(name) + for name in self._player_names + } + + def update_after_event( + self, + event_statement: str, + ) -> None: + chain_of_thought = interactive_document.InteractiveDocument(self._model) + chain_of_thought.statement(f'List of individuals: {self._player_names}') + chain_of_thought.statement(f'List of item types: {self._item_types}') + chain_of_thought.statement(f'Event: {event_statement}') + + inventory_effects = [] + + proceed = chain_of_thought.yes_no_question( + question=( + 'In the above transcript, did any of the listed individuals ' + + 'gain or lose any items on the list of item types? Make sure ' + + 'to take into account items equivalent to the items on the list ' + + 'e.g. if "money" is on the list but the event mentions "gold" ' + + 'then treat "gold" as equivalent to "money" since gold is a type ' + + 'of money.' + ) + ) + if proceed: + if self._financial: + _ = chain_of_thought.open_question( + question=( + 'If the event mentions any financial transaction (buying or ' + 'selling), what price(s) were involved? If no price(s) were ' + 'mentioned then pick logical values for them. If there was no ' + 'transaction then respond with "NA".' 
+ ) + ) + for item_type in self._item_types: + this_item_changed = chain_of_thought.yes_no_question( + question=f'Did any listed individual gain or lose {item_type}?', + ) + if this_item_changed: + players_who_changed_str = chain_of_thought.open_question( + question=( + f'Which individuals gained or lost {item_type}?\n' + + 'Respond with a comma-separated list, for example: \n' + + 'Jacob,Alfred,Patricia' + ) + ) + players_whose_inventory_changed = players_who_changed_str.split(',') + for player in players_whose_inventory_changed: + if player.rstrip(' ') in self._player_names: + prefix = f"[effect on {player}'s {self._name}]" + amount = chain_of_thought.open_question( + question=( + f'How {_many_or_much(self._is_count_noun[item_type])} ' + + f'{item_type} did {player} gain ' + + f'as a result of the event? If they lost {item_type} ' + 'then respond with a negative number.' + ) + ) + try: + amount = float(amount) + except ValueError: + amount = 0.0 + if self._item_types_dict[item_type].force_integer: + if not amount.is_integer(): + inventory_effects.append( + f'{prefix} no effect since amount of {item_type} must ' + + f'be a whole number but {amount} is not.' + ) + continue + old_total = self._inventories[player][item_type] + self._inventories[player][item_type] += amount + maximum = self._item_types_dict[item_type].maximum + minimum = self._item_types_dict[item_type].minimum + self._inventories[player][item_type] = np.min( + [self._inventories[player][item_type], maximum] + ) + self._inventories[player][item_type] = np.max( + [self._inventories[player][item_type], minimum] + ) + # Get amount actually gained/lost once bounds accounted for. 
+ amount = self._inventories[player][item_type] - old_total + effect = '' + if amount > 0: + effect = f'{prefix} gained {amount} {item_type}' + if amount < 0: + absolute_amount = np.abs(amount) + effect = f'{prefix} lost {absolute_amount} {item_type}' + if effect: + if self._is_count_noun[item_type] and np.abs(amount) > 1: + # Add 's' to the end of the noun if it is a count noun. + effect = effect + 's' + inventory_effects.append(effect) + if self._verbose: + print(termcolor.colored(effect, 'yellow')) + + # Update the string representation of all inventories. + self.update() + + if self._verbose: + print(termcolor.colored(chain_of_thought.view().text(), 'yellow')) + print(termcolor.colored(self.state(), 'yellow')) + + update_log = { + 'date': self._clock_now(), + 'Summary': str(self._inventories), + 'Inventories': self.state(), + 'Chain of thought': { + 'Summary': f'{self._name} chain of thought', + 'Chain': chain_of_thought.view().text().splitlines(), + }, + } + self._memory.extend(inventory_effects) + self._history.append(update_log) diff --git a/concordia/environment/components/player_status.py b/concordia/environment/components/player_status.py new file mode 100644 index 00000000..193097b0 --- /dev/null +++ b/concordia/environment/components/player_status.py @@ -0,0 +1,109 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +"""This construct track the status and location of players.""" + +from collections.abc import Callable, Sequence +import datetime + +from concordia.associative_memory import associative_memory +from concordia.document import interactive_document +from concordia.language_model import language_model +from concordia.typing import component + + +class PlayerStatus(component.Component): + """Tracks the status of players.""" + + def __init__( + self, + clock_now: Callable[[], datetime.datetime], + model: language_model.LanguageModel, + memory: associative_memory.AssociativeMemory, + player_names: Sequence[str], + num_memories_to_retrieve: int = 10, + verbose: bool = False, + ): + self._memory = memory + self._model = model + self._state = '' + self._player_names = player_names + self._num_memories_to_retrieve = num_memories_to_retrieve + self._partial_states = {name: '' for name in self._player_names} + self._verbose = verbose + self._history = [] + self._clock_now = clock_now + + def name(self) -> str: + return 'Status of players' + + def state(self) -> str: + return self._state + + def get_history(self): + return self._history.copy() + + def get_last_log(self): + if self._history: + return self._history[-1].copy() + + def partial_state( + self, + player_name: str, + ) -> str: + """Return a player-specific view of the construct's state.""" + return self._partial_states[player_name] + + def update(self) -> None: + self._state = '\n' + self._partial_states = {name: '' for name in self._player_names} + per_player_prompt = {} + for player_name in self._player_names: + query = f'{player_name}' + mems = ( + '\n'.join( + self._memory.retrieve_associative( + query, k=self._num_memories_to_retrieve, add_time=True) + ) + + '\n' + ) + prompt = interactive_document.InteractiveDocument(self._model) + prompt.statement(f'Events:\n{mems}') + time_now = self._clock_now().strftime('[%d %b %Y %H:%M:%S]') + prompt.statement(f'The current time is: {time_now}\n') + player_loc = ( + 
prompt.open_question( + 'Given the above events and their time, what is the latest' + f' location of {player_name} and what are they doing?', + answer_prefix=f'{player_name} is ', + ) + + '\n' + ) + per_player_prompt[player_name] = prompt.view().text().splitlines() + if self._verbose: + print(prompt.view().text()) + + # Indent player status outputs. + player_state_string = f' {player_name} is ' + player_loc + self._partial_states[player_name] = player_state_string + self._state = self._state + player_state_string + + update_log = { + 'date': self._clock_now(), + 'state': self._state, + 'partial states': self._partial_states, + 'per player prompts': per_player_prompt, + } + self._history.append(update_log) diff --git a/concordia/environment/components/relevant_events.py b/concordia/environment/components/relevant_events.py new file mode 100644 index 00000000..cbc7ae7f --- /dev/null +++ b/concordia/environment/components/relevant_events.py @@ -0,0 +1,90 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +"""This component retrieves relevant events from the memory.""" + +from collections.abc import Callable +import datetime + +from concordia.associative_memory import associative_memory +from concordia.language_model import language_model +from concordia.typing import component + + +class RelevantEvents(component.Component): + """Tracks the status of players.""" + + def __init__( + self, + clock_now: Callable[[], datetime.datetime], + model: language_model.LanguageModel, + memory: associative_memory.AssociativeMemory, + name: str = 'Relevant events', + num_memories_retrieved_for_update: int = 10, + add_time: bool = True, + use_recency: bool = True, + ): + """Initializes the component. + + Args: + clock_now: Function that returns the current time. + model: Language model. + memory: Associative memory. + name: Name of the component. + num_memories_retrieved_for_update: Number of memories to retrieve when + updating the state. + add_time: Whether to add the time to the retrieved memories. + use_recency: Whether to use recency in memory retrieval or not. 
+ """ + self._memory = memory + self._model = model + self._state = '' + self._history = [] + self._clock_now = clock_now + self._name = name + self._num_memories_retrieved_for_update = num_memories_retrieved_for_update + self._add_time = add_time + self._use_recency = use_recency + + def name(self) -> str: + return self._name + + def state(self) -> str: + return self._state + + def get_history(self): + return self._history.copy() + + def get_last_log(self): + if self._history: + return self._history[-1].copy() + + def update_before_event(self, cause_statement: str) -> None: + mem_retrieved = self._memory.retrieve_associative( + cause_statement, + use_recency=self._use_recency, + add_time=self._add_time, + k=self._num_memories_retrieved_for_update, + ) + + mems = '\n'.join(mem_retrieved) + self._state = mems + + update_log = { + 'date': self._clock_now(), + 'state': self._state, + 'cause_statement': cause_statement, + } + self._history.append(update_log) diff --git a/concordia/environment/components/schedule.py b/concordia/environment/components/schedule.py new file mode 100644 index 00000000..ade853ee --- /dev/null +++ b/concordia/environment/components/schedule.py @@ -0,0 +1,71 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +"""This construct implements scheduled events.""" + +from collections.abc import Callable +import dataclasses +import datetime +from typing import Optional + +from concordia.typing import component + + +@dataclasses.dataclass(frozen=True) +class EventData: + """Represents an event scheduled to happen at a specific time in the future. + + Attributes: + time: when the event will happen. + description: string to use to condition the game master's narration of the + event. + trigger: a function to call when event happens [optional] + """ + + time: datetime.datetime + description: str + trigger: Optional[Callable[[], None]] = None + + +class Schedule(component.Component): + """A memory construct that represents a schedule of events.""" + + def __init__( + self, + clock, + schedule, + ): + self._clock = clock + self._schedule = schedule + self._state = None + + def name(self) -> str: + return 'Current events' + + def state(self) -> str | None: + return self._state + + def update(self) -> None: + now = self._clock.now() + events = [] + for _, event_data in self._schedule.items(): + if now == event_data.time: + events.append(event_data.description) + if event_data.trigger is not None: + event_data.trigger() + if events: + self._state = '\n'.join(events) + else: + self._state = None diff --git a/concordia/environment/components/time_display.py b/concordia/environment/components/time_display.py new file mode 100644 index 00000000..a471d3f3 --- /dev/null +++ b/concordia/environment/components/time_display.py @@ -0,0 +1,38 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""This component shows current time interval.""" + + +from concordia.typing import clock +from concordia.typing import component + + +class TimeDisplay(component.Component): + """Tracks the status of players.""" + + def __init__( + self, + game_clock: clock.GameClock, + name: str = 'Current time interval', + ): + self._clock = game_clock + self._name = name + + def name(self) -> str: + return self._name + + def state(self) -> str: + return self._clock.current_time_interval_str() diff --git a/concordia/environment/game_master.py b/concordia/environment/game_master.py new file mode 100644 index 00000000..2caf7f9c --- /dev/null +++ b/concordia/environment/game_master.py @@ -0,0 +1,308 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +"""A Generic Game Master.""" + +from collections.abc import Callable, Sequence +import concurrent.futures +import random + +from concordia.agents import basic_agent +from concordia.associative_memory import associative_memory +from concordia.document import interactive_document +from concordia.language_model import language_model +from concordia.thought_chains import thought_chains +from concordia.typing import agent as simulacrum_agent +from concordia.typing import clock as game_clock +from concordia.typing import component +from concordia.typing import game_master as simulacrum_game_master +from concordia.typing import metric +import termcolor + + +DEFAULT_THOUGHTS = [ + thought_chains.attempt_to_result, + thought_chains.result_to_who_what_where, +] + + +class GameMaster(simulacrum_game_master.GameMaster): + """A generic game master.""" + + def __init__( + self, + model: language_model.LanguageModel, + memory: associative_memory.AssociativeMemory, + clock: game_clock.GameClock, + players: Sequence[basic_agent.BasicAgent], + name: str = 'Game Master', + measurements: Sequence[metric.Metric] | None = None, + update_thought_chain: ( + Sequence[ + Callable[[interactive_document.InteractiveDocument, str], str] + ] + | None + ) = None, + components: Sequence[component.Component] | None = None, + action_spec: simulacrum_agent.ActionSpec | None = None, + randomise_initiative: bool = False, + player_observes_event: bool = True, + players_act_simultaneously: bool = True, + verbose: bool = False, + concurrent_externalities: bool = True, + concurrent_action: bool = False, + log_colour: str = 'red', + ): + """Game master constructor. + + Args: + model: a language model + memory: an associative memory + clock: a clock + players: a sequence of generative agent simulacra which is assumed to + contain only information that players also can access. + name: name of the game master. 
+ measurements: sequence of measurements which look at text and store the + answers to questions in python state variables. + update_thought_chain: chain of thoughts for update from player + components: components to condition on + action_spec: specific action_spec to pass to agents, default is used if + None + randomise_initiative: whether to randomise initiative (who goes first ) + order + player_observes_event: send outcome of the players action back as + observation. Helpful to turn off if using direct_effect externality to + avoid duplicate memories. + players_act_simultaneously: advance time after all players have acted, if + false then advance time after each player acts. + verbose: whether to print debugging information or not. + concurrent_externalities: if true, runs externalities in separate threads + concurrent_action: if true, runs player actions and events in separate + threads + log_colour: colour in which to print logs + """ + self._name = name + self._model = model + self._memory = memory + self._clock = clock + self._players = players + self._log_colour = log_colour + self._measurements = measurements or [] + self._randomise_initiative = randomise_initiative + self._player_observes_event = player_observes_event + self._players_act_simultaneously = players_act_simultaneously + self._action_spec = action_spec or simulacrum_agent.DEFAULT_ACTION_SPEC + self._concurrent_action = concurrent_action + + self._components = {} + for comp in components: + if comp.name() in self._components: + raise ValueError(f'Duplicate component name: {comp.name()}') + else: + self._components[comp.name()] = comp + + self._verbose = verbose + + self._update_from_player_thoughts = update_thought_chain or DEFAULT_THOUGHTS + + self._players_by_name = {player.name: player for player in self._players} + + self._concurrent_externalities = concurrent_externalities + self._log = [] + + self.reset() + + def name(self): + return self._name + + def get_history(self): + return 
self._log.copy() + + def get_data_frame(self): + return self._memory.get_data_frame() + + def _print(self, entry, colour=None): + print(termcolor.colored(entry, colour or self._log_colour)) + + def reset(self): + self._last_chain = None + self._num_players = len(self._players) + + def get_player_names(self): + return [player.name for player in self._players] + + def update_from_player(self, player_name: str, action_attempt: str): + prompt = interactive_document.InteractiveDocument(self._model) + + with concurrent.futures.ThreadPoolExecutor() as executor: + executor.map( + lambda construct: construct.update_before_event( + f'{player_name}: {action_attempt}' + ), + self._components, + ) + + for comp in self._components.values(): + state_of_component = comp.state() + if state_of_component: + prompt.statement(comp.name() + ': ' + comp.state() + '\n') + + prompt.statement(f"\n{player_name}'s attempted action: {action_attempt}") + + # Produce the event that has happened as the result of the action attempt + prompt, event_statement = thought_chains.run_chain_of_thought( + self._update_from_player_thoughts, action_attempt, prompt + ) + + self._memory.add(event_statement) + + # This gives duplicates if direct_effect-like component is used + if self._player_observes_event: + self._players_by_name[player_name].observe(event_statement) + + if self._verbose: + self._print( + '\nGM context of action and chain of thought:\n' + + prompt.view().text() + ) + + if self._verbose: + self._print(event_statement, 'white') + + update_log = { + 'date': self._clock.now(), + 'Event statement': event_statement, + 'Summary': event_statement, + 'Chain of thought': { + 'Summary': "Game Master's chain of thought", + 'Chain': prompt.view().text().splitlines(), + }, + 'Active player': { + 'Name': player_name, + 'Action attempt': action_attempt, + 'Chain of thought': self._players_by_name[ + player_name + ].get_last_log(), + }, + } + + # Consequences + def get_externality(externality): + return 
externality.update_after_event(event_statement) + + consequences = [] + if self._concurrent_externalities: + with concurrent.futures.ThreadPoolExecutor() as executor: + for result in executor.map( + get_externality, list(self._components.values())): + if result: + consequences.extend(result) + else: + for externality in self._components.values(): + result = externality.update_after_event(event_statement) + if result: + consequences.extend(result) + + for fact in consequences: + if fact: + if self._verbose: + self._print(fact, 'white') + self._memory.add(fact) # could be multi-threaded, helps with importance + + self._last_chain = prompt + + for externality in self._components.values(): + last_log = externality.get_last_log() + if last_log: + if 'date' in last_log.keys(): + last_log.pop('date') + if 'Event statement' in last_log.keys(): + last_log.pop('Event statement') + + update_log[externality.name()] = last_log + + self._log.append(update_log) + + # MULTI-THREAD + def process_measurements(signal): + return signal.update(event_statement, player_name, prompt) + + with concurrent.futures.ThreadPoolExecutor() as executor: + executor.map(process_measurements, self._measurements) + + return event_statement + + def view_for_player(self, player_name): + """Send observations to a player.""" + for comp in self._components.values(): + state_of_component = comp.partial_state(player_name) + if state_of_component: + self._players_by_name[player_name].observe( + comp.name() + ': ' + state_of_component + ) + + return + + def update_components(self) -> None: + # MULTI THREAD! 
+ with concurrent.futures.ThreadPoolExecutor() as executor: + executor.map( + lambda construct: construct.update(), list(self._components.values())) + + def _step_player(self, player: basic_agent.BasicAgent): + self.update_components() + self.view_for_player(player_name=player.name) + action = player.act(self._action_spec) + + self.update_from_player(action_attempt=action, player_name=player.name) + + def step(self): + """Steps the game. + + At each step players all take a turn 'quasisimultaneously' with regard to + the main game clock, but still in a specific order within the timestep. + This is the same principle as initiative order in dungeons and dragons. + """ + players = list(self._players) + + if self._randomise_initiative: + random.shuffle(players) + + if self._concurrent_action: + with concurrent.futures.ThreadPoolExecutor() as executor: + executor.map(self._step_player, players) + else: + for player in players: + self._step_player(player) + if not self._players_act_simultaneously: + self._clock.advance() + if self._players_act_simultaneously: + self._clock.advance() + + def run_episode(self, max_steps: int = 20) -> list[str]: + for _ in range(max_steps): + self.step() + for comp in self._components.values(): + if comp.terminate_episode(): + return self._memory.retrieve_recent(k=1000, add_time=True) + return self._memory.retrieve_recent(k=1000, add_time=True) + + def add_component(self, comp: component.Component) -> None: + """Add a component to the game master.""" + self._components[comp.name()] = comp + + def remove_component(self, component_name: str) -> None: + """Remove a component from the game master by name.""" + del self._components[component_name] diff --git a/concordia/environment/metrics/__init__.py b/concordia/environment/metrics/__init__.py new file mode 100644 index 00000000..5123c855 --- /dev/null +++ b/concordia/environment/metrics/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2023 DeepMind Technologies Limited. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Library of metrics for generative game master and agents.""" + +from concordia.environment.metrics import common_sense_morality +from concordia.environment.metrics import goal_achievement +from concordia.environment.metrics import reputation +from concordia.environment.metrics import uncertainty_scale_question diff --git a/concordia/environment/metrics/common_sense_morality.py b/concordia/environment/metrics/common_sense_morality.py new file mode 100644 index 00000000..f453601e --- /dev/null +++ b/concordia/environment/metrics/common_sense_morality.py @@ -0,0 +1,115 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +"""Metrics of ethical conduct.""" + +from collections.abc import Sequence +from typing import Any + +from concordia.agents import basic_agent +from concordia.document import interactive_document +from concordia.language_model import language_model +from concordia.typing import clock as game_clock +from concordia.typing import metric + + +DEFAULT_SCALE = ['abhorrent', 'wrong', 'neutral', 'right', 'praiseworthy'] + + +class CommonSenseMoralityMetric(metric.Metric): + """Metric of common sense morality of players actions.""" + + def __init__( + self, + model: language_model.LanguageModel, + players: Sequence[basic_agent.BasicAgent], + clock: game_clock.GameClock, + name: str = 'Morality', + scale: Sequence[str] | None = None, + verbose: bool = False, + writer=None, + ): + """Initializes the metric. + + Args: + model: The language model to use. + players: The players to measure. + clock: The clock to use. + name: The name of the metric. + scale: The scale of the metric, uses default if None. + verbose: Whether to print the metric. + writer: The writer to use. 
+ """ + self._model = model + self._name = name + self._state = [] + self._clock = clock + self._verbose = verbose + self._players = players + self._scale = scale or DEFAULT_SCALE + self._writer = writer + + self._player_by_name = {player.name: player for player in players} + + def name( + self, + ) -> str: + """Returns the name of the measurement.""" + return self._name + + def update( + self, + observation: str, + acting_player_name: str, + doc: interactive_document.InteractiveDocument, + ) -> None: + current_player = self._player_by_name[acting_player_name] + + doc = doc.new() + answer = doc.multiple_choice_question( + ( + f'Evaluate if the following activity of {current_player.name} using' + f' common sense morality.\nActivity: {observation}\n' + ), + self._scale, + ) + answer_str = self._scale[answer] + + answer = float(answer) / float(len(self._scale) - 1) + + datum = { + 'time_str': self._clock.now().strftime('%H:%M:%S'), + 'clock_step': self._clock.get_step(), + 'step_metric': len(self._state), + 'value_float': answer, + 'value_str': answer_str, + 'player': acting_player_name, + } + if self._writer is not None: + self._writer.write(datum) + self._writer.flush() + + datum['time'] = self._clock.now() + self._state.append(datum) + + if self._verbose: + print(f'{self._name} of {current_player.name}: {answer_str}') + + def state(self) -> list[dict[str, Any]]: + """Return the current state of all the tracked variables.""" + return self._state.copy() + + def get_scale(self) -> Sequence[str]: + return self._scale diff --git a/concordia/environment/metrics/goal_achievement.py b/concordia/environment/metrics/goal_achievement.py new file mode 100644 index 00000000..d46c4a36 --- /dev/null +++ b/concordia/environment/metrics/goal_achievement.py @@ -0,0 +1,116 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Metrics of goal achievement per player.""" + +from collections.abc import Sequence +from typing import Any + +from concordia.document import interactive_document +from concordia.language_model import language_model +from concordia.typing import clock as game_clock +from concordia.typing import metric + +DEFAULT_SCALE = [ + 'activity unrelated to the goal', + 'somewhat working towards the goal', + 'working towards the goal', + 'goal achieved', +] + + +class GoalAchievementMetric(metric.Metric): + """Metric of goal achievement per player / goal pair.""" + + def __init__( + self, + model: language_model.LanguageModel, + player_goals: dict[str, str], + clock: game_clock.GameClock, + name: str = 'Goal Achievement', + scale: Sequence[str] | None = None, + verbose: bool = False, + writer=None, + ): + """Initializes the metric. + + Args: + model: Language model to use for the question. + player_goals: Dictionary of player name to player goal. + clock: Clock for logging. + name: Name of the metric. + scale: Scale of the metric, uses default if None. + verbose: Whether to print logs during execution. + writer: Writer to use for logging. 
+ """ + self._model = model + self._name = name + self._state = [] + self._clock = clock + self._verbose = verbose + self._player_goals = player_goals + self._scale = scale or DEFAULT_SCALE + self._writer = writer + + def name( + self, + ) -> str: + """Returns the name of the measurement.""" + return self._name + + def update( + self, + observation: str, + acting_player_name: str, + doc: interactive_document.InteractiveDocument, + ) -> None: + acting_player_goal = self._player_goals[acting_player_name] + doc = doc.new() + answer = doc.multiple_choice_question( + ( + 'Evaluate if the following activity brings' + f' {acting_player_name} closer to their goal' + f' "{acting_player_goal} .\n Activity: {observation}\n' + ), + self._scale, + ) + answer_str = self._scale[answer] + + answer = float(answer) / float(len(self._scale) - 1) + + datum = { + 'time_str': self._clock.now().strftime('%H:%M:%S'), + 'clock_step': self._clock.get_step(), + 'step_metric': len(self._state), + 'value_float': answer, + 'value_str': answer_str, + 'player': acting_player_name, + 'goal': acting_player_goal, + } + if self._writer is not None: + self._writer.write(datum) + self._writer.flush() + datum['time'] = self._clock.now() + + self._state.append(datum) + if self._verbose: + print(f'{self._name} of {acting_player_name}: {answer_str}') + + def state(self) -> list[dict[str, Any]]: + """Return the current state of all the tracked variables.""" + return self._state.copy() + + def get_scale(self) -> Sequence[str]: + return self._scale diff --git a/concordia/environment/metrics/reputation.py b/concordia/environment/metrics/reputation.py new file mode 100644 index 00000000..be33b19f --- /dev/null +++ b/concordia/environment/metrics/reputation.py @@ -0,0 +1,145 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Metrics of player`s reputation among other players.""" + +from collections.abc import Sequence +import concurrent.futures +from typing import Any + +from concordia.agents import basic_agent +from concordia.document import interactive_document +from concordia.language_model import language_model +from concordia.typing import agent as simulacrum_agent +from concordia.typing import clock as game_clock +from concordia.typing import metric +import numpy as np + +DEFAULT_SCALE = [ + 'very negative', + 'somewhat negative', + 'neutral', + 'somewhat positive', + 'very positive', +] + + +class ReputationMetric(metric.Metric): + """Metric of players reputation among the each other.""" + + def __init__( + self, + model: language_model.LanguageModel, + players: Sequence[basic_agent.BasicAgent], + clock: game_clock.GameClock, + name: str = 'Reputation', + scale: Sequence[str] | None = None, + verbose: bool = False, + writer=None, + question: str = 'What is {opining_player}\'s opinion of {of_player}?', + ): + """Initializes the metric. + + Args: + model: Language model to use for the question. + players: List of players. + clock: Clock for logging. + name: Name of the metric. + scale: Scale of the metric, uses default if None. + verbose: Whether to print logs during execution. + writer: Writer to use for logging. + question: The question to ask players about opinions on other players. + Must have two formatting fields: "{opining_player}" and "{of_player}". 
+ """ + self._model = model + self._name = name + self._state = [] + self._clock = clock + self._verbose = verbose + self._players = players + self._scale = scale or DEFAULT_SCALE + self._writer = writer + self._question = question + + self._player_by_name = {player.name: player for player in players} + + def name( + self, + ) -> str: + """Returns the name of the measurement.""" + return self._name + + def update( + self, + observation: str, + acting_player_name: str, + doc: interactive_document.InteractiveDocument, + ) -> None: + del doc, observation # this metric doesn't use either + + def get_reputation(current_player: basic_agent.BasicAgent) -> None: + if current_player.name == acting_player_name: + return + + question = ( + self._question.format(opining_player=current_player.name, + of_player=acting_player_name) + ) + action_spec = simulacrum_agent.ActionSpec( + call_to_action=question, + output_type='CHOICE', + options=self._scale, + ) + + with current_player.interrogate(): + answer_str = current_player.act(action_spec, memorize=False) + answer = np.where(np.array(self._scale) == answer_str)[0][0] + + answer = float(answer) / float(len(self._scale) - 1) + datum = { + 'time_str': self._clock.now().strftime('%H:%M:%S'), + 'clock_step': self._clock.get_step(), + 'step_metric': len(self._state), + 'value_float': answer, + 'value_str': answer_str, + 'player': acting_player_name, + 'rating_player': current_player.name, + } + if self._writer is not None: + self._writer.write(datum) + self._writer.flush() + + datum['time'] = self._clock.now() + self._state.append(datum) + if self._verbose: + print( + f'{self._name} of {acting_player_name} as viewed by ' + f'{current_player.name}:' + f' {answer_str}' + ) + + return + + with concurrent.futures.ThreadPoolExecutor( + max_workers=len(self._players) + ) as executor: + executor.map(get_reputation, self._players) + + def state(self) -> list[dict[str, Any]]: + """Return the current state of all the tracked variables.""" + 
return self._state.copy() + + def get_scale(self) -> Sequence[str]: + return self._scale diff --git a/concordia/environment/metrics/uncertainty_scale_question.py b/concordia/environment/metrics/uncertainty_scale_question.py new file mode 100644 index 00000000..d120b693 --- /dev/null +++ b/concordia/environment/metrics/uncertainty_scale_question.py @@ -0,0 +1,114 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Metrics for tracking the answer to a configurable question.""" + +from collections.abc import Sequence +from typing import Any + +from concordia.agents import basic_agent +from concordia.document import interactive_document +from concordia.language_model import language_model +from concordia.typing import agent as simulacrum_agent +from concordia.typing import clock as game_clock +from concordia.typing import metric +import numpy as np + + +DEFAULT_SCALE = [ + 'Definitively not', + 'Maybe not', + 'Maybe yes', + 'Definitively yes', +] + +DEFAULT_QUESTION = 'Would {agent_name} talk to a stranger?' 
+ + +class Question(metric.Metric): + """Metrics for tracking the answer to a configurable question.""" + + def __init__( + self, + model: language_model.LanguageModel, + players: Sequence[basic_agent.BasicAgent], + clock: game_clock.GameClock, + name: str = 'Question', + question: str | None = None, + scale: Sequence[str] | None = None, + verbose: bool = False, + writer=None, + ): + self._model = model + self._name = name + self._state = [] + self._clock = clock + self._verbose = verbose + self._players = players + self._scale = scale or DEFAULT_SCALE + self._writer = writer + self._question = question or DEFAULT_QUESTION + + self._player_by_name = {player.name: player for player in players} + + def name( + self, + ) -> str: + """Returns the name of the measurement.""" + return self._name + + def update( + self, + observation: str, + acting_player_name: str, + doc: interactive_document.InteractiveDocument, + ) -> None: + del doc, observation # this metric doesn't use either + question = self._question.format(agent_name=acting_player_name) + action_spec = simulacrum_agent.ActionSpec( + call_to_action=question, + output_type='CHOICE', + options=self._scale, + ) + current_player = self._player_by_name[acting_player_name] + + with current_player.interrogate(): + answer_str = current_player.act(action_spec) + answer = np.where(np.array(self._scale) == answer_str)[0][0] + + answer = float(answer) / float(len(self._scale) - 1) + datum = { + 'time_str': self._clock.now().strftime('%H:%M:%S'), + 'clock_step': self._clock.get_step(), + 'step_metric': len(self._state), + 'value_float': answer, + 'value_str': answer_str, + 'player': acting_player_name, + } + if self._writer is not None: + self._writer.write(datum) + self._writer.flush() + + datum['time'] = self._clock.now() + self._state.append(datum) + if self._verbose: + print(f'{question}\n{acting_player_name}: {answer_str}') + + def state(self) -> list[dict[str, Any]]: + """Return the current state of all the tracked 
variables.""" + return self._state.copy() + + def get_scale(self) -> Sequence[str]: + return self._scale diff --git a/concordia/environment/scenes/__init__.py b/concordia/environment/scenes/__init__.py new file mode 100644 index 00000000..21637409 --- /dev/null +++ b/concordia/environment/scenes/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + diff --git a/concordia/environment/scenes/conversation.py b/concordia/environment/scenes/conversation.py new file mode 100644 index 00000000..01218d92 --- /dev/null +++ b/concordia/environment/scenes/conversation.py @@ -0,0 +1,168 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""The conversation scene. 
+
+The conversation scene configures the game master that runs a
+conversation between players, while conditioning them on the full history of the
+conversation at each step through the ConversationTracker component.
+"""
+
+from collections.abc import Sequence
+
+from concordia.agents import basic_agent
+from concordia.associative_memory import blank_memories
+from concordia.clocks import game_clock
+from concordia.document import interactive_document
+from concordia.environment import game_master as game_master_lib
+from concordia.language_model import language_model
+from concordia.thought_chains import thought_chains
+from concordia.typing import agent as simulacrum_agent
+from concordia.typing import component
+from concordia.typing import metric
+import termcolor
+
+
+class ConversationTracker(component.Component):
+  """This component accumulates history of a conversation scene in its state."""
+
+  def __init__(
+      self,
+      model: language_model.LanguageModel,
+      players: Sequence[basic_agent.BasicAgent],
+      premis: str = '',
+      verbose: bool = False,
+      log_colour: str = 'red',
+  ):
+    """This component accumulates history of a conversation scene in its state.
+
+    Args:
+      model: a language model
+      players: players participating
+      premis: any extra text to be added on top of the conversation (say,
+        circumstances of it)
+      verbose: whether or not to print intermediate reasoning steps
+      log_colour: colour for logging
+    """
+    self._model = model
+    self._state = premis
+    self._log_colour = log_colour
+    self._players = players
+
+    self._verbose = verbose
+
+  def name(self) -> str:
+    return 'Conversation history'
+
+  def state(self):
+    return self._state
+
+  def terminate_episode(self) -> bool:
+    chain_of_thought = interactive_document.InteractiveDocument(self._model)
+    chain_of_thought.statement(f'Conversation:\n{self._state}\n')
+
+    did_conclude = chain_of_thought.multiple_choice_question(
+        'Is the conversation above over and not going to continue?',
+        answers=['No', 'Yes'],
+    )
+    if self._verbose:
+      self._log(chain_of_thought.view().text())
+
+    return did_conclude == 1
+
+  def _log(self, entry: str):
+    print(termcolor.colored(entry, self._log_colour), end='')
+
+  def update_after_event(self, event_statement: str):
+    # The event_statement contains the last utterance in the conversation
+    self._state += '\n' + event_statement
+    if self._verbose:
+      self._log(f'Current state of converstion: {self._state}')
+    for player in self._players:
+      player.observe(event_statement)
+
+  def update(self):
+    return self._state
+
+
+def make_conversation_game_master(
+    players: Sequence[basic_agent.BasicAgent],
+    clock: game_clock.MultiIntervalClock,
+    model: language_model.LanguageModel,
+    memory_factory: blank_memories.MemoryFactory,
+    measurements: Sequence[metric.Metric] | None,
+    name: str = 'Conversation scene',
+    premise: str = '',
+):
+  """Creates a game master that runs a conversation between players.
+ + Args: + players: players participating + clock: a clock + model: a language model + memory_factory: a memory factory + measurements: measurements for the game master to use + name: the name of the game master + premise: any extra text to be added on top of the conversation (say, + circumstances of it) + + Returns: + a game master + """ + + action_spec = simulacrum_agent.ActionSpec( + simulacrum_agent.DEFAULT_CALL_TO_SPEECH, + 'FREE', + tag='speech', + ) + + agent_names = [player.name for player in players] + + is_are = 'are' if len(agent_names) > 1 else 'is' + convo = f'{", ".join(agent_names)} {is_are} in conversation' + if premise: + convo = ( + f'{premise}\nAs a result {convo}.\nHere is the conversation from the' + ' beginning:' + ) + + conversation_tracker = ConversationTracker( + model=model, + players=players, + premis=convo, + verbose=True, + log_colour='red', + ) + + for player in players: + player.observe(convo) + + memory = memory_factory.make_blank_memory() + game_master = game_master_lib.GameMaster( + model=model, + memory=memory, + clock=clock, + name=name, + players=players, + measurements=measurements, + components=[conversation_tracker], + action_spec=action_spec, + update_thought_chain=[thought_chains.identity], + randomise_initiative=False, + player_observes_event=False, + concurrent_externalities=False, + verbose=True, + ) + return game_master diff --git a/concordia/examples/phone/__init__.py b/concordia/examples/phone/__init__.py new file mode 100644 index 00000000..21637409 --- /dev/null +++ b/concordia/examples/phone/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + diff --git a/concordia/examples/phone/calendar.ipynb b/concordia/examples/phone/calendar.ipynb new file mode 100644 index 00000000..5f22d764 --- /dev/null +++ b/concordia/examples/phone/calendar.ipynb @@ -0,0 +1,686 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "qFEJE9lTLk0y" + }, + "source": [ + "```\n", + "Copyright 2023 DeepMind Technologies Limited.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "you may not use this file except in compliance with the License.\n", + "You may obtain a copy of the License at\n", + "\n", + " https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "Unless required by applicable law or agreed to in writing, software\n", + "distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "See the License for the specific language governing permissions and\n", + "limitations under the License.\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zWgEkOAO9OVz" + }, + "source": [ + "# Calendar Example\n", + "\n", + "An illustrative social simulation with 2 players which simulates phone interactions. The two players, Alice and Bob, have a smartphone with a Calendar app. Alice's goal is to setup a meeting with Bob using the Calendar app on her phone, taking Bob's scheulde into account when selecting the date/time." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "J2TwJrZ08wXz" + }, + "source": [ + "## Init and import" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-qLG5ExLqpWa" + }, + "outputs": [], + "source": [ + "# @title Imports\n", + "\n", + "import concurrent.futures\n", + "import datetime\n", + "import random\n", + "\n", + "from IPython import display\n", + "\n", + "from concordia.agents import basic_agent\n", + "from concordia.agents import components\n", + "from concordia.associative_memory import associative_memory\n", + "from concordia.associative_memory import blank_memories\n", + "from concordia.associative_memory import embedder_st5\n", + "from concordia.associative_memory import formative_memories\n", + "from concordia.associative_memory import importance_function\n", + "from concordia.clocks import game_clock\n", + "from concordia.environment import components as gm_components\n", + "from concordia.environment import game_master\n", + "from concordia.language_model import sax_model\n", + "from concordia.utils import html as html_lib\n", + "\n", + "from concordia.examples.phone.components import apps\n", + "from concordia.examples.phone.components import triggering" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "I3OtW8flCJSC" + }, + "outputs": [], + "source": [ + "#@title Setup sentence encoder\n", + "embedder = embedder_st5.EmbedderST5()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cugwvFIKv5AS" + }, + "outputs": [], + "source": [ + "# @title SAX Language Model\n", + "\n", + "# Add path to your SAX server here:\n", + "SAX_PATH = '' # @param {type:\"string\"}\n", + "DEFAULT_MAX_TOKENS = 300 # @param {type: 'integer'}\n", + "DEFAULT_TIMEOUT_SECONDS = 60 # @param {type: 'number'}\n", + "\n", + "model = sax_model.SAXLanguageModel(SAX_PATH)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": 
"z9HYjZgyakc_" + }, + "source": [ + "## Configuring the genereric knowledge of the players and the game master (GM)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "b8vWoQ6by51N" + }, + "outputs": [], + "source": [ + "# @title Generic memories are memories that all players and GM share.\n", + "\n", + "shared_memories = [\n", + " 'There is a hamlet named Riverbend.',\n", + " 'Riverbend is an idyllic rural town.',\n", + " 'The river Solripple runs through the village of Riverbend.',\n", + " 'The Solripple is a mighty river.',\n", + " 'Riverbend has a temperate climate.',\n", + " 'Riverbend has a main street.',\n", + " 'There is a guitar store on Main street Riverbend.',\n", + " 'There is a grocery store on Main street Riverbend.',\n", + " 'There is a school on Main street Riverbend.',\n", + " 'There is a library on Main street Riverbend.',\n", + " 'Riverbend has only one pub.',\n", + " 'There is a pub on Main street Riverbend called The Sundrop Saloon.',\n", + " 'Town hall meetings often take place at The Sundrop Saloon.',\n", + " 'Riverbend does not have a park',\n", + " 'The main crop grown on the farms near Riverbend is alfalfa.',\n", + " 'Farms near Riverbend depend on water from the Solripple river.',\n", + " 'There is no need to register in advance to be on the ballot.',\n", + "]\n", + "\n", + "# The generic context will be used for the NPC context. 
It reflects general\n", + "# knowledge and is possessed by all characters.\n", + "shared_context = model.sample_text(\n", + " 'Summarize the following passage in a concise and insightful fashion:\\n'\n", + " + '\\n'.join(shared_memories)\n", + " + '\\n'\n", + " + 'Summary:'\n", + ")\n", + "print(shared_context)\n", + "importance_model = importance_function.ConstantImportanceModel()\n", + "importance_model_gm = importance_function.ConstantImportanceModel()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TeVYseoD2WYa" + }, + "outputs": [], + "source": [ + "#@title Make the clock\n", + "SETUP_TIME = datetime.datetime(hour=8, year=2024, month=9, day=1)\n", + "\n", + "START_TIME = datetime.datetime(hour=9, year=2024, month=10, day=1)\n", + "clock = game_clock.MultiIntervalClock(\n", + " start=SETUP_TIME,\n", + " step_sizes=[datetime.timedelta(hours=1), datetime.timedelta(seconds=10)])\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YBCXUQ8sayzj" + }, + "source": [ + "## Functions to build the players" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OvPT0bnltrAN" + }, + "outputs": [], + "source": [ + "blank_memory_factory = blank_memories.MemoryFactory(\n", + " model=model,\n", + " embedder=embedder,\n", + " importance=importance_model.importance,\n", + " clock_now=clock.now,\n", + ")\n", + "\n", + "formative_memory_factory = formative_memories.FormativeMemoryFactory(\n", + " model=model,\n", + " shared_memories=shared_memories,\n", + " blank_memory_factory_call=blank_memory_factory.make_blank_memory,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "As465DbcsAwZ" + }, + "outputs": [], + "source": [ + "def build_agent(agent_config):\n", + "\n", + " mem = formative_memory_factory.make_memories(agent_config)\n", + "\n", + " # Build the player.\n", + "\n", + " time = components.report_state.ReportState(\n", + " name='Current 
time',\n", + " get_state=clock.current_time_interval_str)\n", + "\n", + " somatic_state = components.somatic_state.SomaticState(\n", + " model, mem, agent_config.name, clock\n", + " )\n", + " identity = components.identity.SimIdentity(model, mem, agent_config.name)\n", + " goal_component = components.constant.ConstantConstruct(state=agent_config.goal)\n", + " plan = components.plan.SimPlan(\n", + " model,\n", + " mem,\n", + " agent_config.name,\n", + " components=[identity],\n", + " goal=goal_component,\n", + " verbose=False,\n", + " )\n", + " current_obs = components.observation.Observation(agent_config.name, mem)\n", + " summary_obs = components.observation.ObservationSummary(\n", + " model=model,\n", + " agent_name=agent_config.name,\n", + " components=[identity],\n", + " )\n", + " agent = basic_agent.BasicAgent(\n", + " model,\n", + " mem,\n", + " agent_name=agent_config.name,\n", + " clock=clock,\n", + " verbose=True,\n", + " components=[identity, plan, somatic_state, summary_obs, current_obs, time],\n", + " )\n", + "\n", + " agent.update()\n", + "\n", + " return agent" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qt8CK2mMbD7q" + }, + "source": [ + "## Configure and build the players" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "o1WDljMsuCTh" + }, + "outputs": [], + "source": [ + "NUM_PLAYERS = 2\n", + "victim = 'Alice'\n", + "\n", + "def make_random_big_five()-\u003estr:\n", + " return str({\n", + " 'extraversion': random.randint(1, 10),\n", + " 'neuroticism': random.randint(1, 10),\n", + " 'openness': random.randint(1, 10),\n", + " 'conscientiousness': random.randint(1, 10),\n", + " 'agreeableness': random.randint(1, 10),\n", + " })\n", + "\n", + "scenario_premise = [\n", + "\n", + " (\n", + " 'Alice, Bob, Charlie and Dorothy are at the Sundrop Saloon. 
There '\n", + " + 'is a snow storm and they have to wait it out inside.'\n", + " ),\n", + "]\n", + "player_configs = [\n", + " formative_memories.AgentConfig(\n", + " name='Alice',\n", + " gender='female',\n", + " goal='Setup a meeting with Bob for two weeks from today using her smartphone.',\n", + " context=f'{shared_context}\\nAlice grew up in Riverbend.',\n", + " traits = make_random_big_five()\n", + " ),\n", + " formative_memories.AgentConfig(\n", + " name='Bob',\n", + " gender='male',\n", + " goal='Just chill and enjoy life.',\n", + " context=f'{shared_context}\\nBob grew up in Riverbend.',\n", + " traits = make_random_big_five()\n", + " ),\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "CBGAqV7-uP2i" + }, + "outputs": [], + "source": [ + "\n", + "player_configs = player_configs[:NUM_PLAYERS]\n", + "player_goals = {player_config.name: player_config.goal for player_config in player_configs}\n", + "players = []\n", + "\n", + "with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_PLAYERS) as pool:\n", + " for agent in pool.map(build_agent, player_configs[:NUM_PLAYERS]):\n", + " players.append(agent)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2vt8ggYUrW8M" + }, + "source": [ + "## Build the GM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "siwglxrc6z2j" + }, + "outputs": [], + "source": [ + "game_master_instructions = (\n", + " 'This is a social science experiment. It is structured as a '\n", + " 'tabletop roleplaying game (like dungeons and dragons). You are the '\n", + " 'game master. You will describe the current situation to the '\n", + " 'participants in the experiment and then on the basis of what you '\n", + " 'tell them they will suggest actions for the character they control. '\n", + " 'Aside from you, each other participant controls just one character. 
'\n", + " 'You are the game master so you may control any non-player '\n", + " 'character. You will track the state of the world and keep it '\n", + " 'consistent as time passes in the simulation and the participants '\n", + " 'take actions and change things in their world. Remember that this '\n", + " 'is a serious social science experiment. It is not just a game. It '\n", + " 'need not be fun for the participants. Always use third-person '\n", + " 'limited perspective, even when speaking directly to the participants.'\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3W65kHOKQwrv" + }, + "outputs": [], + "source": [ + "game_master_memory = associative_memory.AssociativeMemory(\n", + " embedder, importance_model_gm.importance, clock=clock.now)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bGNY_D7FID4I" + }, + "outputs": [], + "source": [ + "for player in players:\n", + " game_master_memory.add(f'{player.name} is at their private home.')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-cxivChc633z" + }, + "outputs": [], + "source": [ + "# @title Create components and externalities\n", + "citizen_names = [player.name for player in players]\n", + "player_names = [player.name for player in players]\n", + "\n", + "instructions_construct = components.constant.ConstantConstruct(game_master_instructions, 'Instructions')\n", + "facts_on_village = components.constant.ConstantConstruct(' '.join(shared_memories), 'General knowledge of Riverbend')\n", + "player_status = gm_components.player_status.PlayerStatus(clock.now, model, game_master_memory, player_names)\n", + "\n", + "relevant_events = gm_components.relevant_events.RelevantEvents(clock.now, model, game_master_memory)\n", + "time_display = gm_components.time_display.TimeDisplay(clock)\n", + "\n", + "\n", + "direct_effect_externality = gm_components.direct_effect.DirectEffect(\n", + " 
players, memory=game_master_memory, model=model, clock_now=clock.now, verbose=False, components=[player_status]\n", + ")\n", + "\n", + "toy_calendar = apps.ToyCalendar()\n", + "phones = [apps.Phone('Alice', apps=[toy_calendar]), apps.Phone('Bob', apps=[toy_calendar])]\n", + "phone_triggering = triggering.SceneTriggeringComponent(players, phones, model, memory=game_master_memory, clock=clock, memory_factory=blank_memory_factory)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "d_R2BVNOsAwa" + }, + "outputs": [], + "source": [ + "# @title Create the game master object\n", + "env = game_master.GameMaster(\n", + " model=model,\n", + " memory=game_master_memory,\n", + " clock=clock,\n", + " players=players,\n", + " components=[\n", + " instructions_construct,\n", + " facts_on_village,\n", + " player_status,\n", + " direct_effect_externality,\n", + " relevant_events,\n", + " time_display,\n", + " phone_triggering,\n", + " ],\n", + " randomise_initiative=True,\n", + " player_observes_event=False,\n", + " verbose=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d2u0bQ1MSCGd" + }, + "source": [ + "## The RUN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "hdTRDaxEZZnN" + }, + "outputs": [], + "source": [ + "clock.set(START_TIME)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9IggLF1aH_hF" + }, + "outputs": [], + "source": [ + "for player in players:\n", + " player.observe( f'{player.name} is at home, they have just woken up.')\n", + "\n", + "with concurrent.futures.ThreadPoolExecutor(max_workers=len(players)) as pool:\n", + " for player in players:\n", + " pool.submit(player.update())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "2Bt87stq76gF" + }, + "outputs": [], + "source": [ + "# @title Expect about 2-3 minutes per step.\n", + "episode_length = 12 # 
@param {type: 'integer'}\n", + "for _ in range(episode_length):\n", + " env.step()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DnwvpvQ4bnFs" + }, + "source": [ + "## Summary and analysis of the episode" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "j71OiuPot5UV" + }, + "source": [ + "## Save results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "O4jp0xGXvOAJ" + }, + "outputs": [], + "source": [ + "# @title Summarize the entire story\n", + "all_gm_memories = env._memory.retrieve_recent(k=10000, add_time=True)\n", + "\n", + "detailed_story = '\\n'.join(all_gm_memories)\n", + "print('len(detailed_story): ', len(detailed_story))\n", + "# print(detailed_story)\n", + "\n", + "episode_summary = model.sample_text(\n", + " f'Sequence of events:\\n{detailed_story}'+\n", + " '\\nNarratively summarize the above temporally ordered ' +\n", + " 'sequence of events. Write it as a news report. Summary:\\n',\n", + " max_characters=8000, max_tokens=8000, terminators=())\n", + "print(episode_summary)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ALG987t-6j-V" + }, + "outputs": [], + "source": [ + "# @title Summarise the perspective of each player\n", + "player_logs = []\n", + "player_log_names = []\n", + "for player in players:\n", + " name = player.name\n", + " detailed_story = '\\n'.join(player._memory.retrieve_recent(k=1000, add_time=True))\n", + " summary = ''\n", + " summary = model.sample_text(\n", + " f'Sequence of events that happened to {name}:\\n{detailed_story}'\n", + " '\\nWrite a short story that summarises these events.\\n'\n", + " ,\n", + " max_characters=8000, max_tokens=8000, terminators=())\n", + "\n", + " all_player_mem = player._memory.retrieve_recent(k=1000, add_time=True)\n", + " all_player_mem = ['Summary:', summary, 'Memories:'] + all_player_mem\n", + " player_html = 
html_lib.PythonObjectToHTMLConverter(all_player_mem).convert()\n", + " player_logs.append(player_html)\n", + " player_log_names.append(f'{name}')\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UmPOvjVxddye" + }, + "source": [ + "#Build and display HTML log of the experiment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JyEoGgI05xI0" + }, + "outputs": [], + "source": [ + "history_sources = [env, direct_effect_externality]\n", + "histories_html = [html_lib.PythonObjectToHTMLConverter(history.get_history()).convert() for history in history_sources]\n", + "histories_names = [history.name() for history in history_sources]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XNJuo4Dwt5Ui" + }, + "outputs": [], + "source": [ + "gm_mem_html = html_lib.PythonObjectToHTMLConverter(all_gm_memories).convert()\n", + "\n", + "tabbed_html = html_lib.combine_html_pages(\n", + " histories_html + [gm_mem_html] + player_logs,\n", + " histories_names + ['GM'] + player_log_names,\n", + " summary=episode_summary,\n", + " title='Calendar experiment',\n", + ")\n", + "\n", + "tabbed_html = html_lib.finalise_html(tabbed_html)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pumxvmrzANOq" + }, + "outputs": [], + "source": [ + "display.HTML(tabbed_html)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HX-M9Im_dneG" + }, + "source": [ + "#Interact with a specific player" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ESJ1l7_Kt5Uj" + }, + "outputs": [], + "source": [ + "sim_to_interact = 'Alice' # @param ['Alice', 'Bob','Charlie', 'Dorothy', 'Ellen'] {type:\"string\"}\n", + "user_identity = 'a close friend' # @param {type:\"string\"}\n", + "interaction_premise = f'{sim_to_interact} is talking to {user_identity}\\n' # @param {type:\"string\"}\n", + "\n", + "player_names = [player.name for player in 
players]\n", + "player_by_name = {player.name: player for player in players}\n", + "selected_player = player_by_name[sim_to_interact]\n", + "interrogation = interaction_premise" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5Q1cYflLt5Uj" + }, + "outputs": [], + "source": [ + "utterence_from_user = 'Did you schedule a meeting with Bob?' # @param {type:\"string\"}\n", + "\n", + "interrogation += f'{user_identity}: {utterence_from_user}'\n", + "player_says = selected_player.say(interrogation)\n", + "interrogation += f'\\n{sim_to_interact}: {player_says}\\n'\n", + "print(interrogation)" + ] + } + ], + "metadata": { + "colab": { + "last_runtime": { + "build_target": "//learning/grp/tools/ml_python:ml_notebook", + "kind": "private" + }, + "private_outputs": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/concordia/examples/phone/components/__init__.py b/concordia/examples/phone/components/__init__.py new file mode 100644 index 00000000..21637409 --- /dev/null +++ b/concordia/examples/phone/components/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + diff --git a/concordia/examples/phone/components/apps.py b/concordia/examples/phone/components/apps.py new file mode 100644 index 00000000..45d2d7b0 --- /dev/null +++ b/concordia/examples/phone/components/apps.py @@ -0,0 +1,264 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Classes for implementing virtual apps simulation.""" + +import abc +from collections.abc import Sequence +import dataclasses +import datetime +import inspect +import re +import textwrap +import typing +from typing import Any + +import docstring_parser + +_DATE_FORMAT = '%Y-%m-%d %H:%M' + +_ARGUMENT_REGEX = re.compile(r'(?P\w+):\s*(?P[^\n]+)') + +_ARGUMENT_PARSERS = { + 'datetime.datetime': lambda date: datetime.datetime.strptime( + date, _DATE_FORMAT + ), + 'str': str, + 'int': int, +} + +_ACTION_PROPERTY = '__app_action__' + + +def app_action(method): + """A decorator that marks PhoneApp methods as callable actions.""" + method.__app_action__ = True + return method + + +class ActionArgumentError(Exception): + """An error that is raised when argument parsing fails.""" + + +@dataclasses.dataclass(frozen=True) +class Parameter: + """Represents a parameter that can be passed to an action.""" + + name: str + kind: type[Any] + description: str | None + + def full_description(self): + return f"{self.name}: {self.description or ''}, type: {self.kind}" + + def value_from_text(self, text: str): + origin = typing.get_origin(self.kind) + if not 
origin: + return self._parse_single_argument(text) + else: + if origin != list: + raise RuntimeError(f'Unsupported argument type {origin}') + return self._parse_list_argument(text) + + def _parse_single_argument(self, text): + parser = _ARGUMENT_PARSERS.get(self.kind, self.kind) + return parser(text) + + def _parse_list_argument(self, text: str): + arg = typing.get_args(self.kind) + parser = _ARGUMENT_PARSERS.get(arg, arg) + return [parser(e) for e in text.split(',')] + + @classmethod + def create( + cls, parameter: inspect.Parameter, docstring: docstring_parser.Docstring + ): + """Create a Parameter from a method docstring and inspect.Parameter.""" + description = next( + ( + p.description + for p in docstring.params + if p.arg_name == parameter.name + ), + None, + ) + return cls(parameter.name, parameter.annotation, description) + + +@dataclasses.dataclass(frozen=True) +class ActionDescriptor: + """Represents an action that can be invoked on a PhoneApp.""" + + name: str + description: str + parameters: Sequence[Parameter] + docstring: dataclasses.InitVar[docstring_parser.Docstring] + + def __post_init__(self, docstring: docstring_parser.Docstring): + pass + + def instructions(self): + return ( + f'The {self.name} action expects the following parameters:\n' + + '\n'.join(p.full_description() for p in self.parameters) + + textwrap.dedent(""" + All parameters must be provided, each in its own line, for example: + param1: value1 + param2: value2 + """) + ) + + @classmethod + def from_method(cls, method): + doc = docstring_parser.parse(method.__doc__) + description = f"{doc.short_description}\n{doc.long_description or ''}" + parameters = inspect.signature(method).parameters.items() + method_parameters = [ + Parameter.create(p, doc) for name, p in parameters if name != 'self' + ] + return cls( + name=method.__name__, + description=description, + parameters=method_parameters, + docstring=doc, + ) + + +class PhoneApp(metaclass=abc.ABCMeta): + """Base class for apps that 
concordia can interact with using plain English. + + Extend this class and decorated any method that should be callable from the + simulation with @app_action. + """ + + @abc.abstractmethod + def name(self) -> str: + """Returns the name of the app.""" + raise NotImplementedError + + @abc.abstractmethod + def description(self) -> str: + """Returns a description of the app.""" + raise NotImplementedError + + def actions(self) -> Sequence[ActionDescriptor]: + """Returns this app's callable actions.""" + methods = inspect.getmembers(self, predicate=inspect.ismethod) + return [ + ActionDescriptor.from_method(m) + for _, m in methods + if hasattr(m, _ACTION_PROPERTY) + ] + + def full_description(self): + """Returns a description of the app and all the actions it supports.""" + return textwrap.dedent(f"""\ + {self.name()}: {self.description()} + The app supports the following actions: + """) + '\n'.join(f'{a.name}: {a.description}' for a in self.actions()) + + def invoke_action(self, action: ActionDescriptor, args_text: str) -> str: + r"""Invokes the action on this app instance with the given arguments. + + Args: + action: The action to invoke. + args_text: The arguments to pass to the action, each in its own line with + a colon separating the parameter name from the value, for example: + 'param1: value1\nparam2: value2' + + Returns: + Textual description of the result of invoking the action. + + Raises: + ActionArgumentError: If any of the arguments expected by the action are + missing. 
+ """ + args = _parse_argument_text(args_text) + for p in action.parameters: + if p.name not in args: + raise ActionArgumentError(f'Parameter {p.name} not provided.') + args[p.name] = p.value_from_text(args[p.name]) + + return getattr(self, action.name)(**args) + + +@dataclasses.dataclass(frozen=True) +class Phone: + """Represent a player's phone.""" + + player_name: str + apps: Sequence[PhoneApp] + + def description(self): + return textwrap.dedent(f"""\ + {self.player_name} has a smartphone. + {self.player_name} uses their phone frequently to achieve their daily goals. + {self.player_name}'s phone has the following apps available: + {', '.join(self.app_names())}." + """) + + def app_names(self): + return [a.name() for a in self.apps] + + +# Parse multiline argument text to a text dictionary: +# 'param1: value1\n param2: value2' is parsed to: +# {'param1': 'value1', 'param2': 'value2'} +def _parse_argument_text(args_text: str) -> dict[str, str]: + matches = _ARGUMENT_REGEX.finditer(args_text) + return {m.group('param'): m.group('value') for m in matches} + + +@dataclasses.dataclass(frozen=True, slots=True) +class _Meeting: + time: str + participant: str + title: str + + +class ToyCalendar(PhoneApp): + """A toy calendar app.""" + + def __init__(self): + self._meetings = [] + + def name(self): + return 'Calendar' + + def description(self): + return 'Lets you schedule meetings with other people.' + + @app_action + def add_meeting(self, time: str, participant: str, title: str): + """Add a meeting to the calendar. + + This action schedule a meeting with the participant + and sends them a notification about the meeting. + + Args: + time: The time of the meeting, e.g., tomorrow, in two weeks. + participant: The name of the participant. + title: The title of the meeting, e.g., Alice / John 1:1. + + Returns: + A description of the added meeting. + Raises: + ActionArgumentError: If the format of any of the arguments is invalid. 
+ """ + meeting = _Meeting(time=time, participant=participant, title=title) + self._meetings.append(meeting) + return ( + f'A meeting with {meeting.participant} was scheduled at {meeting.time}.' + ) diff --git a/concordia/examples/phone/components/logging.py b/concordia/examples/phone/components/logging.py new file mode 100644 index 00000000..9e2e41a3 --- /dev/null +++ b/concordia/examples/phone/components/logging.py @@ -0,0 +1,37 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Logger.""" + +import termcolor + + +class Logger: + """Utility for logs messages depending on verbosity.""" + + def __init__(self, color: str = 'magenta', verbose=False, semi_verbose=True): + self._color = color + self._verbose = verbose + self._semi_verbose = semi_verbose + + def verbose(self, entry: str): + if self._verbose: + self._log(entry) + + def semi_verbose(self, entry: str): + if self._semi_verbose: + self._log(entry) + + def _log(self, entry: str): + print(termcolor.colored(entry, self._color)) diff --git a/concordia/examples/phone/components/scene.py b/concordia/examples/phone/components/scene.py new file mode 100644 index 00000000..63e14dbb --- /dev/null +++ b/concordia/examples/phone/components/scene.py @@ -0,0 +1,136 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A GameMaster that simulates a player's interaction with their phone.""" + +import textwrap + +from concordia.agents import basic_agent +from concordia.associative_memory import blank_memories +from concordia.clocks import game_clock +from concordia.document import interactive_document +from concordia.environment import game_master as game_master_lib +from concordia.examples.phone.components import apps +from concordia.examples.phone.components import logging +from concordia.language_model import language_model +from concordia.thought_chains import thought_chains +from concordia.typing import agent +from concordia.typing import component + + +_PHONE_CALL_TO_ACTION = textwrap.dedent("""\ + What actions would {agent_name} perform with their phone + now to best achieve their goal? + Consider their plan, but deviate from it if necessary. + Give a specific activity that can be performed using a + single app on the phone. For example, {agent_name} uses + the Chat app to send a message to George saying 'hi, what's up?". + """) + +_PHONE_ACTION_SPEC = agent.ActionSpec( + _PHONE_CALL_TO_ACTION, 'FREE', tag='phone' +) + + +def build( + player: basic_agent.BasicAgent, + phone: apps.Phone, + clock: game_clock.MultiIntervalClock, + model: language_model.LanguageModel, + memory_factory: blank_memories.MemoryFactory, +) -> game_master_lib.GameMaster: + """Builds a GameMaster that simulates a player's interaction with their phone. + + Args: + player: The player who is interacting with the phone. + phone: The player's phone. + clock: A clock. 
+ model: A language model. + memory_factory: A memory factory for creating the GM's memory. + + Returns: + """ + memory = memory_factory.make_blank_memory() + phone_component = _PhoneComponent(model, player, phone) + return game_master_lib.GameMaster( + model=model, + memory=memory, + clock=clock, + name='PhoneGameMaster', + players=(player,), + components=(phone_component,), + action_spec=_PHONE_ACTION_SPEC, + update_thought_chain=(thought_chains.identity,), + player_observes_event=False, + ) + + +class _PhoneComponent(component.Component): + """Parses the player's actions and invokes them on phone apps.""" + + def __init__( + self, + model: language_model.LanguageModel, + player: basic_agent.BasicAgent, + phone: apps.Phone, + log_color: str = 'red', + verbose: bool = False, + semi_verbose: bool = True, + ): + self._model = model + self._player = player + self._phone = phone + self._logger = logging.Logger(log_color, verbose, semi_verbose) + self._state = '' + + def name(self) -> str: + return 'PhoneComponent' + + def terminate_episode(self) -> bool: + chain_of_thought = interactive_document.InteractiveDocument(self._model) + chain_of_thought.statement(f'Interaction with phone:\n{self._state}') + + did_conclude = chain_of_thought.yes_no_question( + 'Is the user finished using their phone?' 
+ ) + return did_conclude + + def update_after_event(self, event_statement: str): + self._state += '\n' + event_statement + chain_of_thought = interactive_document.InteractiveDocument(self._model) + chain_of_thought.statement(event_statement) + chain_of_thought.statement(self._phone.description()) + app_index = chain_of_thought.multiple_choice_question( + 'In the above transcript, what app did the user use?', + answers=self._phone.app_names(), + ) + + app = self._phone.apps[app_index] + action_names = [a.name for a in app.actions()] + chain_of_thought.statement(app.description()) + action_index = chain_of_thought.multiple_choice_question( + 'In the above transcript, what action did the user perform?', + answers=action_names, + ) + + action = app.actions()[action_index] + + try: + argument_text = chain_of_thought.open_question( + action.instructions(), terminators=[] + ) + result = app.invoke_action(action, argument_text) + return [result] + except apps.ActionArgumentError: + return [] diff --git a/concordia/examples/phone/components/triggering.py b/concordia/examples/phone/components/triggering.py new file mode 100644 index 00000000..c90efb4c --- /dev/null +++ b/concordia/examples/phone/components/triggering.py @@ -0,0 +1,130 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +"""A component that runs the phone scene when a phone action is detected.""" + +from collections.abc import Sequence + +from concordia.agents import basic_agent +from concordia.associative_memory import associative_memory +from concordia.associative_memory import blank_memories +from concordia.clocks import game_clock +from concordia.document import interactive_document +from concordia.examples.phone.components import apps +from concordia.examples.phone.components import logging +from concordia.examples.phone.components import scene +from concordia.language_model import language_model +from concordia.typing import component +from concordia.utils import helper_functions + + +class SceneTriggeringComponent(component.Component): + """Runs the phone scene when a phone action is detected.""" + + def __init__( + self, + players: Sequence[basic_agent.BasicAgent], + phones: Sequence[apps.Phone], + model: language_model.LanguageModel, + memory: associative_memory.AssociativeMemory, + clock: game_clock.MultiIntervalClock, + memory_factory: blank_memories.MemoryFactory, + log_color: str = 'magenta', + verbose: bool = False, + semi_verbose: bool = True, + ): + self._players = players + self._phones = phones + self._model = model + self._clock = clock + self._memory_factory = memory_factory + self._memory = memory + self._logger = logging.Logger(log_color, verbose, semi_verbose) + + def name(self): + return 'State of phone' + + def _is_phone_event(self, event_statement: str) -> bool: + document = interactive_document.InteractiveDocument(self._model) + document.statement(f'Event: {event_statement}') + + return document.yes_no_question( + 'Did a player interact with their smartphone as part of this event?' + ) + + def _get_player_from_event( + self, event_statement: str + ) -> basic_agent.BasicAgent | None: + document = interactive_document.InteractiveDocument(self._model) + document.statement( + f'Event: {event_statement}. 
This event states that someone interacted' + ' with their phone.' + ) + + for player in self._players: + is_player_using_phone = helper_functions.filter_copy_as_statement( + document + ).yes_no_question( + f'Does the event description explicitly state that {player.name}' + ' interacted with their phone?' + ) + if is_player_using_phone: + return player + + return None + + def _get_phone(self, player_name: str) -> apps.Phone: + return next(p for p in self._phones if p.player_name == player_name) + + def _get_player_using_phone( + self, event_statement: str + ) -> basic_agent.BasicAgent | None: + self._logger.semi_verbose('Checking if the phone was used...') + + if not self._is_phone_event(event_statement): + self._logger.semi_verbose('The phone was not used.') + return None + + player = self._get_player_from_event(event_statement) + + if player is None: + self._logger.semi_verbose('The phone was not used.') + else: + self._logger.semi_verbose(f'Player using the phone: {player.name}') + return player + + def _run_phone_scene(self, player: basic_agent.BasicAgent): + phone_scene = scene.build( + player, + self._get_phone(player.name), + clock=self._clock, + model=self._model, + memory_factory=self._memory_factory, + ) + with self._clock.higher_gear(): + scene_output = phone_scene.run_episode() + + for event in scene_output: + player.observe(event) + self._memory.add(event) + return scene_output + + def update_after_event(self, event_statement: str): + player = self._get_player_using_phone(event_statement) + if player is not None: + self._run_phone_scene(player) + + def partial_state(self, player_name: str): + return self._get_phone(player_name).description() diff --git a/concordia/examples/three_key_questions.ipynb b/concordia/examples/three_key_questions.ipynb new file mode 100644 index 00000000..dc032ad1 --- /dev/null +++ b/concordia/examples/three_key_questions.ipynb @@ -0,0 +1,772 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": 
"VE-6f595AybO" + }, + "source": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9Lj2tJYdLfEU" + }, + "source": [ + "```\n", + "Copyright 2023 DeepMind Technologies Limited.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "you may not use this file except in compliance with the License.\n", + "You may obtain a copy of the License at\n", + "\n", + " https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "Unless required by applicable law or agreed to in writing, software\n", + "distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "See the License for the specific language governing permissions and\n", + "limitations under the License.\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zWgEkOAO9OVz" + }, + "source": [ + "# An example implementing the three key questions\n", + "\n", + "March and Olsen (2011) posit that humans generally act as though they choose their actions by answering three key questions:\n", + "\n", + "1. What kind of situation is this?\n", + "2. What kind of person am I?\n", + "3. 
What does a person such as I do in a situation such as this?\n", + "\n", + "The agents used in this example implement exactly these components, and nothing else.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "J2TwJrZ08wXz" + }, + "source": [ + "## Init and import" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-qLG5ExLqpWa" + }, + "outputs": [], + "source": [ + "# @title Imports\n", + "\n", + "import concurrent.futures\n", + "import datetime\n", + "\n", + "from google.colab import widgets\n", + "from IPython import display\n", + "\n", + "from concordia.agents import basic_agent\n", + "from concordia.agents import components\n", + "from concordia.associative_memory import associative_memory\n", + "from concordia.associative_memory import blank_memories\n", + "from concordia.associative_memory import embedder_st5\n", + "from concordia.associative_memory import formative_memories\n", + "from concordia.associative_memory import importance_function\n", + "from concordia.clocks import game_clock\n", + "from concordia.environment import components as gm_components\n", + "from concordia.environment import game_master\n", + "from concordia.environment.metrics import common_sense_morality\n", + "from concordia.environment.metrics import goal_achievement\n", + "from concordia.environment.metrics import reputation\n", + "from concordia.language_model import sax_model\n", + "from concordia.utils import html as html_lib\n", + "from concordia.utils import plotting\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "I3OtW8flCJSC" + }, + "outputs": [], + "source": [ + "# Setup sentence encoder\n", + "embedder = embedder_st5.EmbedderST5()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cugwvFIKv5AS" + }, + "outputs": [], + "source": [ + "# @title SAX Language Model\n", + "\n", + "# Add path to your SAX server here:\n", + "SAX_PATH = '' # 
@param {type:\"string\"}\n", + "DEFAULT_MAX_TOKENS = 300 # @param {type: 'integer'}\n", + "DEFAULT_TIMEOUT_SECONDS = 60 # @param {type: 'number'}\n", + "\n", + "model = sax_model.SAXLanguageModel(SAX_PATH)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z9HYjZgyakc_" + }, + "source": [ + "## Configuring the genereric knowledge of players and GM." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TeVYseoD2WYa" + }, + "outputs": [], + "source": [ + "#@title Make the clock\n", + "time_step = datetime.timedelta(minutes=20)\n", + "SETUP_TIME = datetime.datetime(hour=20, year=2024, month=10, day=1)\n", + "\n", + "START_TIME = datetime.datetime(hour=18, year=2024, month=10, day=2)\n", + "clock = game_clock.MultiIntervalClock(\n", + " start=SETUP_TIME,\n", + " step_sizes=[time_step, datetime.timedelta(seconds=10)])\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "h4_gUs6wrjPM" + }, + "outputs": [], + "source": [ + "#@title Importance models\n", + "importance_model = importance_function.AgentImportanceModel(model)\n", + "importance_model_gm = importance_function.ConstantImportanceModel()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "b8vWoQ6by51N" + }, + "outputs": [], + "source": [ + "# @title Generic memories are memories that all players and GM share.\n", + "\n", + "shared_memories = [\n", + " 'There is a pub called The Sundrop Saloon.',\n", + " \"Alice stole Bob's car and crashed it.\",\n", + " ('Alice, Bob, Charlie and Dorothy always spend their evenings at the ' +\n", + " 'Sundrop Saloon.')\n", + "]\n", + "\n", + "# The generic context will be used for the NPC context. 
It reflects general\n", + "# knowledge and is possessed by all characters.\n", + "shared_context = model.sample_text(\n", + " 'Summarize the following passage in a concise and insightful fashion:\\n'\n", + " + '\\n'.join(shared_memories)\n", + " + '\\n'\n", + " + 'Summary:'\n", + ")\n", + "print(shared_context)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qt8CK2mMbD7q" + }, + "source": [ + "## Configure and build the players\n", + "\n", + "---\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "CrmDfTNHCVXC" + }, + "outputs": [], + "source": [ + "blank_memory_factory = blank_memories.MemoryFactory(\n", + " model=model,\n", + " embedder=embedder,\n", + " importance=importance_model.importance,\n", + " clock_now=clock.now,\n", + ")\n", + "\n", + "formative_memory_factory = formative_memories.FormativeMemoryFactory(\n", + " model=model,\n", + " shared_memories=shared_memories,\n", + " blank_memory_factory_call=blank_memory_factory.make_blank_memory,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "AXnq6aOZ3ukY" + }, + "outputs": [], + "source": [ + "#@title Creating character backgrounds, goals and traits. Modify to explore how it influences the outcomes\n", + "NUM_PLAYERS = 4\n", + "\n", + "scenario_premise = [\n", + "\n", + " (\n", + " 'Alice, Bob, Charlie and Dorothy are at the Sundrop Saloon. 
There '\n", + " + 'is a snow storm and they have to wait it out inside.'\n", + " ),\n", + "]\n", + "player_configs = [\n", + " formative_memories.AgentConfig(\n", + " name='Alice',\n", + " gender='female',\n", + " goal='Alice wants Bob to accept his car is trashed and back off.',\n", + " context=shared_context,\n", + " traits='responsibility: high; aggression: low',\n", + " ),\n", + " formative_memories.AgentConfig(\n", + " name='Bob',\n", + " gender='male',\n", + " goal='Bob wants Alice to pay for his car.',\n", + " context=shared_context,\n", + " traits='responsibility: high; aggression: low',\n", + " ),\n", + " formative_memories.AgentConfig(\n", + " name='Charlie',\n", + " gender='male',\n", + " goal='Charlie wants Alice to apologise.',\n", + " context=shared_context,\n", + " traits='responsibility: low; aggression: high',\n", + " ),\n", + " formative_memories.AgentConfig(\n", + " name='Dorothy',\n", + " gender='female',\n", + " goal=(\n", + " 'Dorothy wants to create a conflict between Bob and Alice, because'\n", + " ' it is funny.'\n", + " ),\n", + " context=shared_context,\n", + " traits='responsibility: medium; aggression: high',\n", + " ),\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4aS2sY22B1JQ" + }, + "outputs": [], + "source": [ + "def build_agent(agent_config):\n", + "\n", + " mem = formative_memory_factory.make_memories(agent_config)\n", + "\n", + " self_perception = components.self_perception.SelfPerception(\n", + " name='self perception',\n", + " model=model,\n", + " memory=mem,\n", + " agent_name=agent_config.name,\n", + " state_clock=clock,\n", + " verbose=True,\n", + " )\n", + " situation_perception = components.situation_perception.SituationPerception(\n", + " name='situation perception',\n", + " model=model,\n", + " memory=mem,\n", + " agent_name=agent_config.name,\n", + " state_clock=clock,\n", + " verbose=True,\n", + " )\n", + " person_by_situation = 
components.person_by_situation.PersonBySituation(\n", + " name='person by situation',\n", + " model=model,\n", + " memory=mem,\n", + " agent_name=agent_config.name,\n", + " state_clock=clock,\n", + " components=[self_perception, situation_perception],\n", + " verbose=True,\n", + " )\n", + " persona = components.sequential.Sequential(\n", + " name='persona',\n", + " components=[\n", + " self_perception,\n", + " situation_perception,\n", + " person_by_situation,\n", + " ],\n", + " )\n", + " current_time_component = components.report_state.ReportState(name='Current time',\n", + " get_state=clock.current_time_interval_str)\n", + "\n", + " current_obs = components.observation.Observation(agent_config.name, mem)\n", + " summary_obs = components.observation.ObservationSummary(\n", + " model=model,\n", + " agent_name=agent_config.name,\n", + " components=[persona],\n", + " )\n", + " agent = basic_agent.BasicAgent(\n", + " model,\n", + " mem,\n", + " agent_name=agent_config.name,\n", + " clock=clock,\n", + " verbose=False,\n", + " components=[persona, current_time_component,summary_obs, current_obs],\n", + " update_interval=time_step,\n", + " )\n", + " return agent\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5RU3ZV4oIknW" + }, + "outputs": [], + "source": [ + "player_configs = player_configs[:NUM_PLAYERS]\n", + "\n", + "players = []\n", + "\n", + "with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_PLAYERS) as pool:\n", + " for agent in pool.map(build_agent, player_configs[:NUM_PLAYERS]):\n", + " players.append(agent)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2vt8ggYUrW8M" + }, + "source": [ + "## Build GM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "siwglxrc6z2j" + }, + "outputs": [], + "source": [ + "game_master_instructions = (\n", + " 'This is a social science experiment. 
It is structured as a '\n", + " 'tabletop roleplaying game (like dungeons and dragons). You are the '\n", + " 'game master. You will describe the current situation to the '\n", + " 'participants in the experiment and then on the basis of what you '\n", + " 'tell them they will suggest actions for the character they control. '\n", + " 'Aside from you, each other participant controls just one character. '\n", + " 'You are the game master so you may control any non-player '\n", + " 'character. You will track the state of the world and keep it '\n", + " 'consistent as time passes in the simulation and the participants '\n", + " 'take actions and change things in their world. Remember that this '\n", + " 'is a serious social science experiment. It is not just a game. It '\n", + " 'need not be fun for the participants. Always use third-person '\n", + " 'limited perspective, even when speaking directly to the participants. '\n", + " 'Players can not leave the Sundrop Saloon, since it is snowed in.'\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3W65kHOKQwrv" + }, + "outputs": [], + "source": [ + "game_master_memory = associative_memory.AssociativeMemory(\n", + " sentence_embedder=embedder,\n", + " importance=importance_model_gm.importance,\n", + " clock=clock.now)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-cxivChc633z" + }, + "outputs": [], + "source": [ + "# @title Create components of the Game Master\n", + "player_names = [player.name for player in players]\n", + "\n", + "instructions_construct = components.constant.ConstantConstruct(\n", + " state=game_master_instructions,\n", + " name='Instructions')\n", + "scenario_knowledge = components.constant.ConstantConstruct(\n", + " state=' '.join(shared_memories),\n", + " name='Background')\n", + "\n", + "player_status = gm_components.player_status.PlayerStatus(\n", + " clock_now=clock.now,\n", + " model=model,\n", + " 
memory=game_master_memory,\n", + " player_names=player_names)\n", + "\n", + "\n", + "convo_externality = gm_components.conversation.Conversation(\n", + " players=players,\n", + " model=model,\n", + " memory=game_master_memory,\n", + " clock=clock,\n", + " burner_memory_factory=blank_memory_factory,\n", + " components=[player_status],\n", + " cap_nonplayer_characters=3,\n", + " game_master_instructions=game_master_instructions,\n", + " shared_context=shared_context,\n", + " verbose=False,\n", + ")\n", + "\n", + "direct_effect_externality = gm_components.direct_effect.DirectEffect(\n", + " players=players,\n", + " model=model,\n", + " memory=game_master_memory,\n", + " clock_now=clock.now,\n", + " verbose=False,\n", + " components=[player_status]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5SpNVmlh6_hp" + }, + "outputs": [], + "source": [ + "# @title Metrics\n", + "player_goals = {\n", + " player_config.name: player_config.goal for player_config in player_configs\n", + "}\n", + "\n", + "goal_metric = goal_achievement.GoalAchievementMetric(\n", + " model, player_goals, clock, 'Goal achievement', verbose=False)\n", + "morality_metric = common_sense_morality.CommonSenseMoralityMetric(\n", + " model, players, clock, 'Morality', verbose=False)\n", + "reputation_metric = reputation.ReputationMetric(\n", + " model, players, clock, 'Reputation', verbose=False)\n", + "\n", + "metrics = [goal_metric, morality_metric, reputation_metric]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "d_R2BVNOsAwa" + }, + "outputs": [], + "source": [ + "# @title Create the game master object\n", + "env = game_master.GameMaster(\n", + " model=model,\n", + " memory=game_master_memory,\n", + " clock=clock,\n", + " players=players,\n", + " components=[\n", + " instructions_construct,\n", + " scenario_knowledge,\n", + " player_status,\n", + " convo_externality,\n", + " direct_effect_externality,\n", + " 
],\n", + " measurements=metrics,\n", + " randomise_initiative=True,\n", + " player_observes_event=False,\n", + " verbose=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LXykV_TdwfKq" + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d2u0bQ1MSCGd" + }, + "source": [ + "## The RUN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "hdTRDaxEZZnN" + }, + "outputs": [], + "source": [ + "clock.set(START_TIME)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9IggLF1aH_hF" + }, + "outputs": [], + "source": [ + "for premis in scenario_premise:\n", + " game_master_memory.add(premis)\n", + " for player in players:\n", + " player.observe(premis)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2Bt87stq76gF" + }, + "outputs": [], + "source": [ + "# @title Expect about 2-3 minutes per step.\n", + "episode_length = 3 # @param {type: 'integer'}\n", + "for _ in range(episode_length):\n", + " env.step()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DnwvpvQ4bnFs" + }, + "source": [ + "## Summary and analysis of the episode" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5U5FDXvs4HSr" + }, + "outputs": [], + "source": [ + "# @title Metrics plotting\n", + "tb = widgets.TabBar([metric.name() for metric in metrics])\n", + "\n", + "for metric in metrics:\n", + " with tb.output_to(metric.name()):\n", + " plotting.plot_metric_line(metric)\n", + " plotting.plot_metric_pie(metric)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "j71OiuPot5UV" + }, + "source": [ + "## Save results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "O4jp0xGXvOAJ" + }, + "outputs": [], + "source": [ + "# @title Summarize the entire story.\n", + 
"all_gm_memories = env._memory.retrieve_recent(k=10000, add_time=True)\n", + "\n", + "detailed_story = '\\n'.join(all_gm_memories)\n", + "print('len(detailed_story): ', len(detailed_story))\n", + "# print(detailed_story)\n", + "\n", + "episode_summary = model.sample_text(\n", + " f'Sequence of events:\\n{detailed_story}'+\n", + " '\\nNarratively summarize the above temporally ordered ' +\n", + " 'sequence of events. Write it as a news report. Summary:\\n',\n", + " max_characters=8000, max_tokens=8000, terminators=())\n", + "print(episode_summary)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ALG987t-6j-V" + }, + "outputs": [], + "source": [ + "# @title Summarise the perspective of each player\n", + "player_logs = []\n", + "player_log_names = []\n", + "for player in players:\n", + " name = player.name\n", + " detailed_story = '\\n'.join(player._memory.retrieve_recent(k=1000,\n", + " add_time=True))\n", + " summary = ''\n", + " summary = model.sample_text(\n", + " f'Sequence of events that happened to {name}:\\n{detailed_story}'\n", + " '\\nWrite a short story that summarises these events.\\n'\n", + " ,\n", + " max_characters=8000, max_tokens=8000, terminators=())\n", + "\n", + " all_player_mem = player._memory.retrieve_recent(k=1000, add_time=True)\n", + " all_player_mem = ['Summary:', summary, 'Memories:'] + all_player_mem\n", + " player_html = html_lib.PythonObjectToHTMLConverter(all_player_mem).convert()\n", + " player_logs.append(player_html)\n", + " player_log_names.append(f'{name}')\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UmPOvjVxddye" + }, + "source": [ + "#Build and display HTML log of the experiment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JyEoGgI05xI0" + }, + "outputs": [], + "source": [ + "history_sources = [env, direct_effect_externality, convo_externality]\n", + "histories_html = [\n", + " 
html_lib.PythonObjectToHTMLConverter(history.get_history()).convert()\n", + " for history in history_sources]\n", + "histories_names = [history.name() for history in history_sources]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XNJuo4Dwt5Ui" + }, + "outputs": [], + "source": [ + "gm_mem_html = html_lib.PythonObjectToHTMLConverter(all_gm_memories).convert()\n", + "\n", + "tabbed_html = html_lib.combine_html_pages(\n", + " histories_html + [gm_mem_html] + player_logs,\n", + " histories_names + ['GM'] + player_log_names,\n", + " summary=episode_summary,\n", + " title='Friends in a pub experiment',\n", + ")\n", + "\n", + "tabbed_html = html_lib.finalise_html(tabbed_html)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pumxvmrzANOq" + }, + "outputs": [], + "source": [ + "display.HTML(tabbed_html)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HX-M9Im_dneG" + }, + "source": [ + "#Interact with a specific player" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ESJ1l7_Kt5Uj" + }, + "outputs": [], + "source": [ + "sim_to_interact = 'Alice' # @param ['Alice', 'Bob','Charlie', 'Dorothy', 'Ellen'] {type:\"string\"}\n", + "user_identity = 'a close friend' # @param {type:\"string\"}\n", + "interaction_premise = f'{sim_to_interact} is talking to {user_identity}\\n' # @param {type:\"string\"}\n", + "\n", + "player_names = [player.name for player in players]\n", + "player_by_name = {player.name: player for player in players}\n", + "selected_player = player_by_name[sim_to_interact]\n", + "interrogation = interaction_premise" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5Q1cYflLt5Uj" + }, + "outputs": [], + "source": [ + "utterence_from_user = 'Did Bob accept your appology?' 
# @param {type:\"string\"}\n", + "\n", + "interrogation += f'{user_identity}: {utterence_from_user}'\n", + "player_says = selected_player.say(interrogation)\n", + "interrogation += f'\\n{sim_to_interact}: {player_says}\\n'\n", + "print(interrogation)" + ] + } + ], + "metadata": { + "colab": { + "last_runtime": { + "build_target": "", + "kind": "private" + }, + "private_outputs": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/concordia/language_model/__init__.py b/concordia/language_model/__init__.py new file mode 100644 index 00000000..21637409 --- /dev/null +++ b/concordia/language_model/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + diff --git a/concordia/language_model/gcloud_model.py b/concordia/language_model/gcloud_model.py new file mode 100644 index 00000000..68946899 --- /dev/null +++ b/concordia/language_model/gcloud_model.py @@ -0,0 +1,106 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Google Cloud Language Model.""" + +from collections.abc import Collection, Sequence +import sys + +from concordia.language_model import language_model +from concordia.utils import text +from google import auth +import vertexai +from vertexai.preview import language_models as vertex_models + +DEFAULT_MAX_TOKENS = 50 +MAX_MULTIPLE_CHOICE_ATTEMPTS = 20 + + +class CloudLanguageModel(language_model.LanguageModel): + """Language model via a google cloud API.""" + + def __init__( + self, + project_id: str, + model_name: str = 'text-bison@001', + location: str = 'us-central1', + credentials: auth.credentials.Credentials = None + ) -> None: + """Initializes a model instance using the Google Cloud language model API. + + Args: + project_id: Google Cloud project id in API calls. + model_name: which language model to use + location: The location to use when making API calls. + credentials: Custom credentials to use when making API calls. If not + provided credentials will be ascertained from the environment. 
+ """ + if not credentials: + credentials = auth.default()[0] + vertexai.init( + project=project_id, location=location, credentials=credentials) + self._model = vertex_models.TextGenerationModel.from_pretrained(model_name) + + def sample_text( + self, + prompt: str, + *, + timeout: float = None, + max_tokens: int = DEFAULT_MAX_TOKENS, + max_characters: int = sys.maxsize, + terminators: Collection[str] = (), + temperature: float = 0.5, + seed: int | None = None, + ) -> str: + """See base class.""" + if timeout is not None: + raise NotImplementedError('Unclear how to set timeout for cloud models.') + if seed is not None: + raise NotImplementedError('Unclear how to set seed for cloud models.') + + max_tokens = min(max_tokens, max_characters) + sample = self._model.predict( + prompt, + temperature=temperature, + max_output_tokens=max_tokens,) + return text.truncate( + sample.text, max_length=max_characters, delimiters=terminators + ) + + def sample_choice( + self, + prompt: str, + responses: Sequence[str], + *, + seed: int | None = None, + ) -> tuple[int, str, dict[str, float]]: + """See base class.""" + max_characters = max([len(response) for response in responses]) + + for _ in range(MAX_MULTIPLE_CHOICE_ATTEMPTS): + sample = self.sample_text( + prompt, + max_tokens=1, + max_characters=max_characters, + temperature=0.0, + seed=seed) + try: + idx = responses.index(sample) + except ValueError: + continue + else: + debug = {} + return idx, responses[idx], debug + + raise language_model.InvalidResponseError( + 'Too many multiple choice attempts.') diff --git a/concordia/language_model/language_model.py b/concordia/language_model/language_model.py new file mode 100644 index 00000000..c0810fce --- /dev/null +++ b/concordia/language_model/language_model.py @@ -0,0 +1,93 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Base class for a language model.""" + +import abc +from collections.abc import Collection, Mapping, Sequence +import sys +from typing import Any + +DEFAULT_MAX_TOKENS = 50 +DEFAULT_TEMPERATURE = 0.5 +DEFAULT_MAX_CHARACTERS = sys.maxsize +DEFAULT_TERMINATORS = () + + +class InvalidResponseError(Exception): + """Exception to throw when exceeding max attempts to get a choice.""" + pass + + +class LanguageModel(metaclass=abc.ABCMeta): + """Language model from LRL library.""" + + @abc.abstractmethod + def sample_text( + self, + prompt: str, + *, + max_tokens: int = DEFAULT_MAX_TOKENS, + max_characters: int = DEFAULT_MAX_CHARACTERS, + terminators: Collection[str] = DEFAULT_TERMINATORS, + temperature: float = DEFAULT_TEMPERATURE, + seed: int | None = None, + ) -> str: + """Samples text from the model. + + NOTE: Sampling method is up to the underlying implementation and may not + reflect the underlying log_probabilities. + + Args: + prompt: the initial text to condition on. + max_tokens: the maximum number of tokens in the response. + max_characters: the maximum number of characters in the response. + terminators: the response will be terminated before any of these + characters. + temperature: temperature for the model. + seed: optional seed for the sampling. If None a random seed will be used. + + Returns: + The sampled response (i.e. does not iclude the prompt). 
+ """ + raise NotImplementedError + + @abc.abstractmethod + def sample_choice( + self, + prompt: str, + responses: Sequence[str], + *, + seed: int | None = None, + ) -> tuple[int, str, Mapping[str, Any]]: + """Samples a response from those available. + + NOTE: Sampling method is up to the underlying implementation and may not + reflect the underlying log_probabilities. + + Args: + prompt: the initial text to condition on. + responses: the responses to score. + seed: optional seed for the sampling. If None a random seed will be used. + + Returns: + (index, response, info). The index of the sampled response, the sampled + response, and some info about the sampling process. + + Raises: + InvalidResponseError if unable to produce a valid choice after attempting + a number of times. + """ + raise NotImplementedError diff --git a/concordia/language_model/retry_wrapper.py b/concordia/language_model/retry_wrapper.py new file mode 100644 index 00000000..ad234011 --- /dev/null +++ b/concordia/language_model/retry_wrapper.py @@ -0,0 +1,88 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Wrapper to retry calls to an underlying language model.""" + +from collections.abc import Collection, Sequence +import copy +from typing import Any, Mapping, Tuple, Type + +from concordia.language_model import language_model +import retry + + +class RetryLanguageModel(language_model.LanguageModel): + """Wraps an underlying language model and retries calls to it.""" + + def __init__( + self, + model: language_model.LanguageModel, + retry_on_exceptions: Collection[Type[Exception]] = (Exception,), + retry_tries: float = 3., + retry_delay: float = 2., + jitter: Tuple[float, float] = (0.0, 1.0), + ) -> None: + """Wrap the underlying language model with retries on given exceptions. + + Args: + model: A language model to wrap with retries. + retry_on_exceptions: the exception exceptions to retry on. + retry_tries: number of retries before failing. + retry_delay: minimum delay between retries. + jitter: tuple of minimum and maximum jitter to add to the retry. + """ + self._model = model + self._retry_on_exceptions = copy.deepcopy(retry_on_exceptions) + self._retry_tries = retry_tries + self._retry_delay = retry_delay + self._jitter = jitter + + def sample_text( + self, + prompt: str, + *, + max_tokens: int = language_model.DEFAULT_MAX_TOKENS, + max_characters: int = language_model.DEFAULT_MAX_CHARACTERS, + terminators: Collection[str] = language_model.DEFAULT_TERMINATORS, + temperature: float = language_model.DEFAULT_TEMPERATURE, + seed: int | None = None, + ) -> str: + """See base class.""" + @retry.retry(self._retry_on_exceptions, tries=self._retry_tries, + delay=self._retry_delay, jitter=self._jitter) + def _sample_text(model, prompt, *, max_tokens=max_tokens, + max_characters=max_characters, terminators=terminators, + temperature=temperature, seed=seed): + return model.sample_text( + prompt, max_tokens=max_tokens, max_characters=max_characters, + terminators=terminators, temperature=temperature, seed=seed) + + return _sample_text(self._model, prompt, 
max_tokens=max_tokens, + max_characters=max_characters, terminators=terminators, + temperature=temperature, seed=seed) + + def sample_choice( + self, + prompt: str, + responses: Sequence[str], + *, + seed: int | None = None, + ) -> tuple[int, str, Mapping[str, Any]]: + """See base class.""" + @retry.retry(self._retry_on_exceptions, tries=self._retry_tries, + delay=self._retry_delay, jitter=self._jitter) + def _sample_choice(model, prompt, responses, *, seed): + return model.sample_choice(prompt, responses, seed=seed) + + return _sample_choice(self._model, prompt, responses, seed=seed) diff --git a/concordia/language_model/sax_model.py b/concordia/language_model/sax_model.py new file mode 100644 index 00000000..a6e7a1dc --- /dev/null +++ b/concordia/language_model/sax_model.py @@ -0,0 +1,149 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Language Model that uses Saxml server. 
+ +https://github.com/google/saxml +""" + +from collections.abc import Collection, Sequence +import concurrent.futures +import sys + +from concordia.language_model import language_model +from concordia.utils import text +import numpy as np +from saxml.client.python import sax +from scipy import special + +DEFAULT_MAX_TOKENS = 50 +DEFAULT_TIMEOUT_SECONDS = 60 +DEFAULT_NUM_CONNECTIONS = 3 + + +class SAXLanguageModel(language_model.LanguageModel): + """Language Model that uses Saxml server.""" + + def __init__( + self, + path: str, + num_conn: int = DEFAULT_NUM_CONNECTIONS, + deterministic_multiple_choice=False, + ) -> None: + """Initializes the instance. + + Args: + path: sax path of model. + num_conn: preferred number of connections to sax backend. + deterministic_multiple_choice: if True, sample_response returns the + response with max probability instead of sampling. + """ + options = sax.Options() + options.num_conn = num_conn + self._model = sax.Model(path, options).LM() + self._deterministic_multiple_choice = deterministic_multiple_choice + + def sample_text( + self, + prompt: str, + *, + timeout: float = DEFAULT_TIMEOUT_SECONDS, + max_tokens: int = DEFAULT_MAX_TOKENS, + max_characters: int = sys.maxsize, + terminators: Collection[str] = (), + temperature: float = 0.5, + seed: int | None = None, + ) -> str: + """Samples a string from the model. + + Args: + prompt: the prompt to generate a response for. + timeout: timeout for the request. + max_tokens: maximum number of tokens to generate. + max_characters: maximum number of characters to generate. + terminators: delimiters to use in the generated response. + temperature: temperature for the model. + seed: seed for the random number generator. + + Returns: + A string of the generated response. 
+ """ + if seed is not None: + raise NotImplementedError('Unclear how to set seed for sax models.') + max_tokens = min(max_tokens, max_characters) + options = sax.ModelOptions() + options.SetTimeout(timeout) + options.SetExtraInput('per_example_max_decode_steps', max_tokens) + options.SetExtraInput('temperature', temperature) + (sample, _), *_ = self._model.Generate(prompt, options) + return text.truncate( + sample, max_length=max_characters, delimiters=terminators + ) + + def sample_choice( + self, + prompt: str, + responses: Sequence[str], + *, + seed: int | None = None, + ) -> tuple[int, str, dict[str, float]]: + """Samples a response from the model. + + Args: + prompt: the prompt to generate a response for. + responses: the responses to sample. + seed: seed for the random number generator. + + Returns: + A tuple of (index, response, debug). + """ + scores = self._score_responses(prompt, responses) + probs = special.softmax(scores) + entropy = probs @ np.log(probs) + if self._deterministic_multiple_choice: + idx = np.argmax(probs, axis=0) + else: + idx = np.random.default_rng(seed).choice(len(probs), p=probs) + debug = {'probs': probs, 'entropy': entropy} + return idx, responses[idx], debug + + def _score_responses( + self, + prompt: str, + responses: Sequence[str], + ) -> np.ndarray: + """Returns the relative log_likelihood of the provided responses. + + Args: + prompt: the prompt preceding the response. + responses: the responses to score. 
+ + Returns: + log Pr(response|prompt) + """ + if isinstance(responses, str): + raise TypeError('responses must be a Sequence') + + def get_score(response, model): + return model.Score(prompt, [response])[0] + + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [ + executor.submit(get_score, response, self._model) + for response in responses + ] + scores = [future.result() for future in futures] + + return np.array(list(scores)) diff --git a/concordia/metrics/__init__.py b/concordia/metrics/__init__.py new file mode 100644 index 00000000..4ad71829 --- /dev/null +++ b/concordia/metrics/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/concordia/metrics/common_sense_morality.py b/concordia/metrics/common_sense_morality.py new file mode 100644 index 00000000..550d8707 --- /dev/null +++ b/concordia/metrics/common_sense_morality.py @@ -0,0 +1,109 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Metrics of ethical conduct.""" + +from collections.abc import Sequence + +from concordia.document import interactive_document +from concordia.language_model import language_model +from concordia.typing import clock as game_clock +from concordia.typing import component +from concordia.utils import measurements as measurements_lib + + +DEFAULT_SCALE = ('abhorrent', 'wrong', 'neutral', 'right', 'praiseworthy') +DEFAULT_CHANNEL_NAME = 'common_sense_morality' + + +class CommonSenseMoralityMetric(component.Component): + """Metric of common sense morality of players actions.""" + + def __init__( + self, + model: language_model.LanguageModel, + player_name: str, + clock: game_clock.GameClock, + name: str = 'Morality', + scale: Sequence[str] | None = None, + verbose: bool = False, + measurements: measurements_lib.Measurements | None = None, + channel: str = DEFAULT_CHANNEL_NAME, + ): + """Initializes the metric. + + Args: + model: The language model to use. + player_name: The player to ask about. + clock: The clock to use. + name: The name of the metric. + scale: The scale of the metric, uses default if None. + verbose: Whether to print the metric. + measurements: The measurements to use. 
+    channel: The name of the channel to push data
+    """
+    self._model = model
+    self._name = name
+    self._clock = clock
+    self._verbose = verbose
+    self._player_name = player_name
+    self._scale = scale or DEFAULT_SCALE
+    self._measurements = measurements
+    self._channel = channel
+
+    self._timestep = 0
+
+  def name(
+      self,
+  ) -> str:
+    """See base class."""
+    return self._name
+
+  def observe(self, observation: str) -> None:
+    """See base class."""
+    doc = interactive_document.InteractiveDocument(self._model)
+    answer = doc.multiple_choice_question(
+        (
+            f'Evaluate the following activity of {self._player_name} using'
+            f' common sense morality.\nActivity: {observation}\n'
+        ),
+        self._scale,
+    )
+    answer_str = self._scale[answer]
+
+    answer = float(answer) / float(len(self._scale) - 1)
+
+    datum = {
+        'time_str': self._clock.now().strftime('%H:%M:%S'),
+        'clock_step': self._clock.get_step(),
+        'timestep': self._timestep,
+        'value_float': answer,
+        'value_str': answer_str,
+        'player': self._player_name,
+    }
+    if self._measurements:
+      self._measurements.publish_datum(self._channel, datum)
+
+    datum['time'] = self._clock.now()
+
+    if self._verbose:
+      print(f'{self._name} of {self._player_name}: {answer_str}')
+    self._timestep += 1
+
+  def state(
+      self,
+  ) -> str | None:
+    """Returns the current state of the component."""
+    return ''
diff --git a/concordia/metrics/goal_achievement.py b/concordia/metrics/goal_achievement.py
new file mode 100644
index 00000000..6a18f9d2
--- /dev/null
+++ b/concordia/metrics/goal_achievement.py
@@ -0,0 +1,116 @@
+# Copyright 2023 DeepMind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Metric to track goal achievement for a player.""" + +from collections.abc import Sequence + +from concordia.document import interactive_document +from concordia.language_model import language_model +from concordia.typing import clock as game_clock +from concordia.typing import component +from concordia.utils import measurements as measurements_lib + +DEFAULT_SCALE = ( + 'activity unrelated to the goal', + 'somewhat working towards the goal', + 'working towards the goal', + 'goal achieved', +) +DEFAULT_CHANNEL_NAME = 'goal_achievement' + + +class GoalAchievementMetric(component.Component): + """Metric of goal achievement for a player and its goal.""" + + def __init__( + self, + model: language_model.LanguageModel, + player_name: str, + player_goal: str, + clock: game_clock.GameClock, + name: str = 'Goal Achievement', + scale: Sequence[str] = DEFAULT_SCALE, + measurements: measurements_lib.Measurements | None = None, + channel: str = DEFAULT_CHANNEL_NAME, + verbose: bool = False, + ): + """Initializes the metric. + + Args: + model: Language model to use for the question. + player_name: player name. + player_goal: player goal. + clock: Clock for logging. + name: Name of the metric. + scale: Scale of the metric, uses default if None. + measurements: The measurements object to publish data to. + channel: Channel to use for logging the metric. + verbose: Whether to print logs during execution. 
+  """
+    self._model = model
+    self._player_name = player_name
+    self._player_goal = player_goal
+    self._clock = clock
+    self._name = name
+    self._scale = scale
+    self._measurements = measurements
+    self._channel = channel
+    self._verbose = verbose
+
+    self._timestep = 0
+
+  def name(
+      self,
+  ) -> str:
+    """See base class."""
+    return self._name
+
+  def observe(self, observation: str) -> None:
+    """See base class."""
+    doc = interactive_document.InteractiveDocument(self._model)
+    answer = doc.multiple_choice_question(
+        (
+            'Evaluate if the following activity brings'
+            f' {self._player_name} closer to their goal'
+            f' "{self._player_goal}".\n Activity: {observation}\n'
+        ),
+        self._scale,
+    )
+    answer_str = self._scale[answer]
+
+    answer = float(answer) / float(len(self._scale) - 1)
+
+    datum = {
+        'time_str': self._clock.now().strftime('%H:%M:%S'),
+        'clock_step': self._clock.get_step(),
+        'timestep': self._timestep,
+        'value_float': answer,
+        'value_str': answer_str,
+        'player': self._player_name,
+        'goal': self._player_goal,
+    }
+    datum['time'] = self._clock.now()
+
+    if self._measurements:
+      self._measurements.publish_datum(self._channel, datum)
+    if self._verbose:
+      print(f'{self._name} of {self._player_name}: {answer_str}')
+    self._timestep += 1
+
+  def state(
+      self,
+  ) -> str | None:
+    """Returns the current state of the component."""
+    return ''
diff --git a/concordia/metrics/opinion_of_others.py b/concordia/metrics/opinion_of_others.py
new file mode 100644
index 00000000..b9332133
--- /dev/null
+++ b/concordia/metrics/opinion_of_others.py
@@ -0,0 +1,162 @@
+# Copyright 2023 DeepMind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Metric of player's opinion of other players.""" + +from collections.abc import Sequence +import concurrent.futures +from typing import Callable + +from concordia.document import interactive_document +from concordia.language_model import language_model +from concordia.typing import clock as game_clock +from concordia.typing import component +from concordia.utils import measurements as measurements_lib + +DEFAULT_SCALE = ( + 'very negative', + 'somewhat negative', + 'neutral', + 'somewhat positive', + 'very positive', +) +DEFAULT_CHANNEL_NAME = 'opinion_of_others' + + +class OpinionOfOthersMetric(component.Component): + """Metric of opinion of other players by a player. + + This component triggers a series of questions on `update`, one for each player + in `player_names`. The context for all questions is given by the callable + `context_fn`, which is called only once. The responses to the question are + evaluated with the given scale, and logged as a datum in the specified channel + of the measurements. + """ + + def __init__( + self, + *, + model: language_model.LanguageModel, + player_name: str, + player_names: Sequence[str], + context_fn: Callable[[], str], + clock: game_clock.GameClock, + name: str = 'Opinion', + scale: Sequence[str] = DEFAULT_SCALE, + verbose: bool = False, + measurements: measurements_lib.Measurements | None = None, + channel: str = DEFAULT_CHANNEL_NAME, + question: str = 'What is {opining_player}\'s opinion of {of_player}?', + ): + """Initializes the metric. + + Args: + model: Language model to use for the question. 
+ player_name: The name of the player opining on others. + player_names: List of player names, might include the opining player. + context_fn: The function to get the context text for the question. + (typically this is the player state). This function will be called on + `update`. + clock: Clock for logging. + name: Name of the metric. + scale: Scale of the metric, uses default if None. + verbose: Whether to print logs during execution. + measurements: The measurements object to publish data to. + channel: Channel to use for logging the metric. + question: The question to ask the player about opinions on other players. + Must have two formatting fields: "{opining_player}" and "{of_player}". + + Raises: + ValueError: If player_names or scale are empty. + """ + self._model = model + self._name = name + self._clock = clock + self._verbose = verbose + self._player_name = player_name + if player_names: + self._player_names = list(player_names) + else: + raise ValueError('player_names must be specified.') + self._context_fn = context_fn + if scale: + self._scale = list(scale) + else: + raise ValueError('scale must be specified.') + self._measurements = measurements + self._channel = channel + self._question = question + + self._timestep = 0 + + def name( + self, + ) -> str: + """Returns the name of the measurement.""" + return self._name + + def update(self) -> None: + """See base class.""" + def get_opinion(of_player: str) -> None: + if of_player == self._player_name: + return # No self opinions. 
+ + prompt = interactive_document.InteractiveDocument(self._model) + parent_state = self._context_fn() + prompt.statement(parent_state) + + question = self._question.format( + opining_player=self._player_name, + of_player=of_player, + ) + + answer = prompt.multiple_choice_question( + question=question, answers=self._scale, + ) + answer_str = self._scale[answer] + + answer_float = float(answer) / float(len(self._scale) - 1) + datum = { + 'time_str': self._clock.now().strftime('%H:%M:%S'), + 'clock_step': self._clock.get_step(), + 'timestep': self._timestep, + 'value_float': answer_float, + 'value_str': answer_str, + 'opining_player': self._player_name, + 'of_player': of_player, + } + if self._measurements: + self._measurements.publish_datum(self._channel, datum) + + datum['time'] = self._clock.now() + if self._verbose: + print( + f'{self._name} of {of_player} as viewed by ' + f'{self._player_name}: {answer_str}' + ) + + return + + with concurrent.futures.ThreadPoolExecutor( + max_workers=len(self._player_names) + ) as executor: + executor.map(get_opinion, self._player_names) + self._timestep += 1 + + def state( + self, + ) -> str | None: + """Returns the current state of the component.""" + return '' diff --git a/concordia/metrics/uncertainty_scale_question.py b/concordia/metrics/uncertainty_scale_question.py new file mode 100644 index 00000000..317f07a6 --- /dev/null +++ b/concordia/metrics/uncertainty_scale_question.py @@ -0,0 +1,132 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Metric for tracking the answer to a configurable question.""" + +from collections.abc import Sequence +from typing import Callable + +from concordia.document import interactive_document +from concordia.language_model import language_model +from concordia.typing import clock as game_clock +from concordia.typing import component +from concordia.utils import measurements as measurements_lib + + +DEFAULT_SCALE = ( + 'Definitively not', + 'Maybe not', + 'Maybe yes', + 'Definitively yes', +) + +DEFAULT_QUESTION = 'Would {player_name} talk to a stranger?' +DEFAULT_CHANNEL_NAME = 'question' + + +class Question(component.Component): + """Metrics for tracking the answer to a configurable question. + + This component triggers a question on `update`. The context for the question + is given by the callable `context_fn`, which is called only once. The response + to the question is evaluated with the given scale, and logged as a datum in + the specified channel of the measurements. + """ + + def __init__( + self, + model: language_model.LanguageModel, + player_name: str, + context_fn: Callable[[], str], + clock: game_clock.GameClock, + name: str = 'Question', + question: str = DEFAULT_QUESTION, + scale: Sequence[str] = DEFAULT_SCALE, + verbose: bool = False, + measurements: measurements_lib.Measurements | None = None, + channel: str = DEFAULT_CHANNEL_NAME, + ): + """Initializes the component. + + Args: + model: The model (LLM) to use. + player_name: The name of the player. + context_fn: The function to get the parent state (typically a player) + clock: The clock of the simulation. + name: The name of the component. + question: The question to ask. Might have the formatting "{player_name}" + which will be replaced by the player's name. + scale: The possible answer options for the question. + verbose: whether to `print` the outcome. 
+ measurements: the measurements object to publish data. + channel: Name of the channel to publish measurements to. + """ + self._model = model + self._player_name = player_name + self._context_fn = context_fn + self._clock = clock + self._name = name + self._question = question + if scale: + self._scale = list(scale) + else: + raise ValueError('scale must be specified.') + self._verbose = verbose + self._measurements = measurements + self._channel = channel + + self._timestep = 0 + + def name( + self, + ) -> str: + """Returns the name of the measurement.""" + return self._name + + def update(self) -> None: + """See base class.""" + prompt = interactive_document.InteractiveDocument(self._model) + parent_state = self._context_fn() + prompt.statement(parent_state) + + question = self._question.format(player_name=self._player_name) + + answer = prompt.multiple_choice_question( + question=question, answers=self._scale, + ) + answer_str = self._scale[answer] + + answer_float = answer / (len(self._scale) - 1) + datum = { + 'time_str': self._clock.now().strftime('%H:%M:%S'), + 'clock_step': self._clock.get_step(), + 'timestep': self._timestep, + 'value_float': answer_float, + 'value_str': answer_str, + 'player': self._player_name, + } + if self._measurements is not None: + self._measurements.publish_datum(self._channel, datum) + + datum['time'] = self._clock.now() + if self._verbose: + print(f'{question}\n{self._player_name}: {answer_str}') + self._timestep += 1 + + def state( + self, + ) -> str | None: + """Returns the current state of the component.""" + return '' diff --git a/concordia/tests/__init__.py b/concordia/tests/__init__.py new file mode 100644 index 00000000..4ad71829 --- /dev/null +++ b/concordia/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/concordia/tests/concordia_integration_test.py b/concordia/tests/concordia_integration_test.py new file mode 100644 index 00000000..f682b2be --- /dev/null +++ b/concordia/tests/concordia_integration_test.py @@ -0,0 +1,227 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import datetime +from typing import List +from absl.testing import absltest +from absl.testing import parameterized +from concordia.agents import basic_agent +from concordia.agents import components +from concordia.associative_memory import associative_memory +from concordia.associative_memory import blank_memories +from concordia.associative_memory import importance_function +from concordia.clocks import game_clock +from concordia.environment import components as gm_components +from concordia.environment import game_master +from concordia.environment.metrics import common_sense_morality +from concordia.environment.metrics import goal_achievement +from concordia.environment.metrics import reputation +from concordia.tests import mock_model +import numpy as np + + +def embedder(text: str): + del text + return np.random.rand(16) + + +def _make_agent( + name: str, + model: mock_model.MockModel, + clock: game_clock.MultiIntervalClock, + game_master_instructions: str, + mem_factory: blank_memories.MemoryFactory, +) -> basic_agent.BasicAgent: + """Creates two agents with the same game master instructions.""" + mem = mem_factory.make_blank_memory() + agent = basic_agent.BasicAgent( + model, + mem, + name, + clock, + [ + components.constant.ConstantConstruct( + 'Instructions:', game_master_instructions + ), + components.constant.ConstantConstruct( + 'General knowledge:', 'this is a test' + ), + components.observation.Observation('Alice', mem), + ], + verbose=True, + ) + + return agent + + +def _make_environment( + model: mock_model.MockModel, + clock: game_clock.MultiIntervalClock, + players: List[basic_agent.BasicAgent], + game_master_instructions: str, + importance_model_gm: importance_function.ImportanceModel, +) -> game_master.GameMaster: + """Creates a game master environment.""" + game_master_memory = associative_memory.AssociativeMemory( + embedder, importance_model_gm.importance, clock=clock.now + ) + player_names = [player.name for player in players] + + 
shared_memories = [ + 'There is a hamlet named Riverbend.', + ] + + shared_context = 'There is a hamlet named Riverbend.' + + instructions_construct = components.constant.ConstantConstruct( + game_master_instructions, 'Instructions' + ) + facts_on_village = components.constant.ConstantConstruct( + ' '.join(shared_memories), 'General knowledge of Riverbend' + ) + player_status = gm_components.player_status.PlayerStatus( + clock.now, model, game_master_memory, player_names + ) + + mem_factory = blank_memories.MemoryFactory( + model=model, + embedder=embedder, + importance=importance_model_gm.importance, + clock_now=clock.now, + ) + + convo_externality = gm_components.conversation.Conversation( + players, + model, + memory=game_master_memory, + clock=clock, + burner_memory_factory=mem_factory, + components=[player_status], + cap_nonplayer_characters=2, + game_master_instructions=game_master_instructions, + shared_context=shared_context, + verbose=False, + ) + + direct_effect_externality = gm_components.direct_effect.DirectEffect( + players, + memory=game_master_memory, + model=model, + clock_now=clock.now, + verbose=False, + components=[player_status], + ) + + debug_event_time = datetime.datetime(hour=14, year=2024, month=10, day=1) + + schedule = { + 'start': gm_components.schedule.EventData( + time=datetime.datetime(hour=9, year=2024, month=10, day=1), + description='', + ), + 'debug_event': gm_components.schedule.EventData( + time=debug_event_time, + description='Debug event', + ), + } + + schedule_construct = gm_components.schedule.Schedule( + clock=clock, schedule=schedule + ) + player_goals = {'Alice': 'win', 'Bob': 'win'} + goal_metric = goal_achievement.GoalAchievementMetric( + model, player_goals, clock, 'Goal achievement', verbose=False + ) + morality_metric = common_sense_morality.CommonSenseMoralityMetric( + model, players, clock, 'Morality', verbose=False + ) + reputation_metric = reputation.ReputationMetric( + model, players, clock, 'Reputation', 
verbose=False + ) + + env = game_master.GameMaster( + model=model, + memory=game_master_memory, + clock=clock, + players=players, + components=[ + instructions_construct, + facts_on_village, + player_status, + schedule_construct, + convo_externality, + direct_effect_externality, + ], + measurements=[goal_metric, morality_metric, reputation_metric], + randomise_initiative=True, + player_observes_event=False, + verbose=False, + ) + return env + + +class GameMasterTest(parameterized.TestCase): + + def test_full_run(self): + model = mock_model.MockModel() + + importance_model = importance_function.ConstantImportanceModel() + + clock = game_clock.MultiIntervalClock( + start=datetime.datetime(hour=8, year=2024, month=9, day=1), + step_sizes=[ + datetime.timedelta(hours=1), + datetime.timedelta(seconds=10), + ], + ) + + game_master_instructions = 'This is a social science experiment.' + + mem_factory = blank_memories.MemoryFactory( + model=model, + embedder=embedder, + importance=importance_model.importance, + clock_now=clock.now, + ) + + alice = _make_agent( + name='Alice', + model=model, + clock=clock, + game_master_instructions=game_master_instructions, + mem_factory=mem_factory, + ) + bob = _make_agent( + name='Bob', + model=model, + clock=clock, + game_master_instructions=game_master_instructions, + mem_factory=mem_factory, + ) + + players = [alice, bob] + + env = _make_environment( + model, + clock, + players, + game_master_instructions, + importance_model, + ) + + env.run_episode(12) + + +if __name__ == '__main__': + absltest.main() diff --git a/concordia/tests/mock_model.py b/concordia/tests/mock_model.py new file mode 100644 index 00000000..24521ff0 --- /dev/null +++ b/concordia/tests/mock_model.py @@ -0,0 +1,67 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A mock Language Model.""" + +from collections.abc import Collection, Sequence +import sys + +from concordia.language_model import language_model + + +class MockModel(language_model.LanguageModel): + """Mock LLM with fixed responses.""" + + def __init__( + self, response: str = 'Quick brown fox jumps over a lazy dog' + ) -> None: + """Initializes the instance. + + Args: + response: string that the model returns when sampling text + """ + self._response = response + + def sample_text( + self, + prompt: str, + *, + timeout: float = 0, + max_tokens: int = 0, + max_characters: int = sys.maxsize, + terminators: Collection[str] = (), + temperature: float = 0.5, + seed: int | None = None, + ) -> str: + """See base class.""" + del ( + prompt, + timeout, + max_tokens, + max_characters, + terminators, + seed, + temperature, + ) + return self._response + + def sample_choice( + self, + prompt: str, + responses: Sequence[str], + *, + seed: int | None = None, + ) -> tuple[int, str, dict[str, float]]: + """See base class.""" + del prompt, seed + return 0, responses[0], {} diff --git a/concordia/thought_chains/__init__.py b/concordia/thought_chains/__init__.py new file mode 100644 index 00000000..21637409 --- /dev/null +++ b/concordia/thought_chains/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + diff --git a/concordia/thought_chains/thought_chains.py b/concordia/thought_chains/thought_chains.py new file mode 100644 index 00000000..ebab3386 --- /dev/null +++ b/concordia/thought_chains/thought_chains.py @@ -0,0 +1,185 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Chain of thoughts abstraction for simulacra and game master.""" + +from collections.abc import Callable, Sequence + +from concordia.document import interactive_document + + +def identity( + chain_of_thought: interactive_document.InteractiveDocument, + premise: str, +): + """Outputs the premise. Use this to create a pass-through chain of thought. 
+ + Args: + chain_of_thought: the document to condition on and record the thoughts + premise: the attempted action + + Returns: + string describing the outcome + """ + del chain_of_thought + return premise + + +def determine_success_and_why( + chain_of_thought: interactive_document.InteractiveDocument, + action_attempt: str, +): + """Determine success of action_attempt and reason for success/failure. + + Args: + chain_of_thought: the document to condition on and record the thoughts + action_attempt: the attempted action + + Returns: + string describing the outcome + """ + success = chain_of_thought.yes_no_question( + 'Does the attempted action succeed? If the attempted action ' + + 'is easy to accomplish then the attempt should usually be successful ' + + 'unless there are specific reason for it to fail.' + ) + why_failed = 'this failed' # will be overwritten if needed. + if success: + chain_of_thought.statement('The attempt succeeded.') + else: + chain_of_thought.statement('The attempt failed.') + why_failed = chain_of_thought.open_question( + 'Why did the attempt fail?', max_characters=1200, max_tokens=1200 + ) + + if action_attempt[-1] == '.': + action_attempt = action_attempt[:-1] + ',' + success_or_not = 'successful' if success else 'not successful' + result = f'{action_attempt} and was {success_or_not}.' + if not success: + result = f'{result}. However, {why_failed}' + + chain_of_thought.statement(result) + return result + + +def result_to_causal_statement( + chain_of_thought: interactive_document.InteractiveDocument, event: str +): + """Determines the causal outcome of the event. 
+ + Args: + chain_of_thought: the document to condition on and record the thoughts + event: the event to determine the causal outcome of + + Returns: + """ + effect = chain_of_thought.open_question( + 'Because of that, what happens as a result?', + max_characters=1200, + max_tokens=1200, + ) + + # MAKING CAUSAL STATEMENT + raw_causal_statement = f'{event} Because of that, {effect}' + causal_statement = chain_of_thought.open_question( + 'Rewrite the following statements to be one sentence and to better ' + 'highlight cause and effect. Do not express uncertainty (e.g. say ' + + '"Francis released the demon" not "Francis could release the demon" ' + + 'and not "The demon may have been released")\n' + + 'Statements: ' + + raw_causal_statement + + '\n' + ) + return causal_statement + + +def attempt_to_result( + chain_of_thought: interactive_document.InteractiveDocument, + action_attempt: str, +): + """Determine success of action_attempt and reason for success/failure. + + Args: + chain_of_thought: the document to condition on and record the thoughts + action_attempt: the attempted action + + Returns: + string describing the outcome + """ + + result = chain_of_thought.open_question( + 'What happens as a result of the attempted action?' + ' Consider status and location of each player.', + max_characters=1200, + max_tokens=1200, + ) + + # MAKING CAUSAL STATEMENT + raw_causal_statement = f'{action_attempt} Because of that, {result}' + + # chain_of_thought.statement(result) + return raw_causal_statement + + +def result_to_who_what_where( + chain_of_thought: interactive_document.InteractiveDocument, event: str +): + """Determines who have done what where, given the event. 
+ + Args: + chain_of_thought: the document to condition on and record the thoughts + event: the event to determine the causal outcome of + + Returns: + """ + + chain_of_thought.statement(event) + causal_statement = chain_of_thought.open_question( + 'Rewrite the statements above to be one sentence and to better highlight' + ' who the event is about, where and what did they do, what happened as a' + ' result. Do not express uncertainty (e.g. say ' + + '"Francis released the demon" not "Francis could release the demon" ' + + 'and not "The demon may have been released")\n', + max_characters=3000, + max_tokens=1500, + ) + return causal_statement + + +def run_chain_of_thought( + thoughts: Sequence[ + Callable[[interactive_document.InteractiveDocument, str], str] + ], + premise: str, + document: interactive_document.InteractiveDocument, +): + """Run a chain of thoughts in the document. + + Args: + thoughts: a sequence of 'thougth' functions + premise: the starting premise of the chain + document: the working document + + Returns: + document: the final version of the document that recorded the chain + conclusion: the result of the last thought + """ + conclusion = premise + + for f in thoughts: + conclusion = f(document, premise) + premise = conclusion + return document, conclusion diff --git a/concordia/typing/__init__.py b/concordia/typing/__init__.py new file mode 100644 index 00000000..21637409 --- /dev/null +++ b/concordia/typing/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + diff --git a/concordia/typing/agent.py b/concordia/typing/agent.py new file mode 100644 index 00000000..1d4e4976 --- /dev/null +++ b/concordia/typing/agent.py @@ -0,0 +1,116 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""The abstract class that defines simulacrum agent interface. + +It has a name and generates actions in response to observations and outcomes of +it's previous actions +Reference: Generative Agents: Interactive Simulacra of Human Behavior +https://arxiv.org/abs/2304.03442 +""" + +import abc +from collections.abc import Sequence +import dataclasses + + +@dataclasses.dataclass(frozen=True) +class ActionSpec: + """A specification of the action that agent is queried for. + + Attributes: + call_to_action: fromated text that conditions agents response. {agent_name} + and {timedelta} will be inserted by the agent. + output_type: type of output - FREE, CHOICE or FLOAT + options: if multiple choice, then provide possible answers here + tag: a tag to add to the activity memory (e.g. action, speach, etc.) + """ + + call_to_action: str + output_type: str + options: Sequence[str] | None = None + tag: str | None = None + + +OUTPUT_TYPES = ['FREE', 'CHOICE', 'FLOAT'] + +DEFAULT_CALL_TO_SPEECH = ( + 'Given the above, what did {agent_name} say? 
Respond in' + ' the format `{agent_name} says: "..."` For example, ' + 'Cristina says: "Hello! Mighty fine weather today, right?" ' + 'or Ichabod says: "I wonder if the alfalfa is ready to harvest.\n' +) + +DEFAULT_CALL_TO_ACTION = ( + 'What would {agent_name} do for the next' + ' {timedelta} to best achieve their goal? Consider their' + ' plan, but deviate from it if necessary. ' + 'Give a specific activity. Pick an activity that ' + 'would normally take about {timedelta} to complete. ' + 'If the selected action has a direct or indirect object then it ' + 'must be specified explicitly. For example, it is valid to respond ' + 'with "{agent_name} votes for Caroline because..." but not ' + 'valid to respond with "{agent_name} votes because...".' +) + + +DEFAULT_ACTION_SPEC = ActionSpec( + call_to_action=DEFAULT_CALL_TO_ACTION, + output_type='FREE', + options=None, + tag='action', +) + + +class GenerativeAgent(metaclass=abc.ABCMeta): + """An agent interface for taking actions.""" + + @property + @abc.abstractmethod + def name( + self, + ) -> str: + """The name of the agent.""" + raise NotImplementedError + + @abc.abstractmethod + def act(self, action_spec: ActionSpec = DEFAULT_ACTION_SPEC) -> str: + """Returns the agent's intended action.""" + raise NotImplementedError + + @abc.abstractmethod + def observe( + self, + observation: str, + ) -> None: + """Integrate observation into simulacrum's memory and components.""" + raise NotImplementedError + + +class SpeakerGenerativeAgent(metaclass=abc.ABCMeta): + """A simulacrum interface for simple conversation.""" + + @property + @abc.abstractmethod + def name( + self, + ) -> str: + """The name of the agent.""" + raise NotImplementedError + + @abc.abstractmethod + def say(self, conversation: str) -> str: + """Returns the agent's response in the conversation.""" + raise NotImplementedError diff --git a/concordia/typing/clock.py b/concordia/typing/clock.py new file mode 100644 index 00000000..3ef60059 --- /dev/null +++ 
b/concordia/typing/clock.py @@ -0,0 +1,48 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""An abstract class of a clock for synchronising the simulation.""" + +import abc +import datetime + + +class GameClock(metaclass=abc.ABCMeta): + """An abstract clock for synchronising simulation.""" + + @abc.abstractmethod + def advance(self): + """Advances the clock.""" + raise NotImplementedError + + def set(self, time: datetime.datetime): + """Sets the clock to a specific time.""" + raise NotImplementedError + + def now(self) -> datetime.datetime: + """Returns the current time.""" + raise NotImplementedError + + def get_step_size(self) -> datetime.timedelta: + """Returns the step size.""" + raise NotImplementedError + + def get_step(self) -> int: + """Returns the current step.""" + raise NotImplementedError + + def current_time_interval_str(self) -> str: + """Returns the current time interval.""" + raise NotImplementedError diff --git a/concordia/typing/component.py b/concordia/typing/component.py new file mode 100644 index 00000000..e33786c5 --- /dev/null +++ b/concordia/typing/component.py @@ -0,0 +1,106 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Base class for generative agent (and game master) components.""" + +import abc + + +class Component(metaclass=abc.ABCMeta): + """A building block of a generative agent / game master. + + A concept constructed from memory or observations stream or (game master) + event statements. Components mediate memory and observations into the + context of action. In general, each component is updated by querying for + relevant memories and then summarising the result. + """ + + @abc.abstractmethod + def name( + self, + ) -> str: + """Returns the name of the component.""" + raise NotImplementedError + + def state( + self, + ) -> str | None: + """Returns the current state of the component.""" + pass + + def partial_state( + self, + player_name: str, + ) -> str | None: + """Returns the specified player's view of the component's current state.""" + del player_name + return None + + def observe( + self, + observation: str, + ) -> None: + """Observe data.""" + del observation + return None + + def update( + self, + ) -> None: + """Updates the component from memory. + + Returns: + The updated state of the component. + """ + pass + + def update_before_event( + self, + cause_statement: str, + ) -> None: + """Updates the component player`s action attempt. + + Args: + cause_statement: The cause statement to update the component before event. + + Returns: + New state of the component or None. + """ + del cause_statement + return None + + def update_after_event( + self, + event_statement: str, + ) -> None: + """Updates the component from the event statement and document. 
+ + Args: + event_statement: The event statement to update the component from. + + Returns: + The summary of the update or None. + """ + del event_statement + return None + + def terminate_episode(self) -> bool: + return False + + def get_last_log( + self, + ): + """Returns a dictionary with latest log of activity.""" + return None diff --git a/concordia/typing/game_master.py b/concordia/typing/game_master.py new file mode 100644 index 00000000..8315e083 --- /dev/null +++ b/concordia/typing/game_master.py @@ -0,0 +1,84 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""The abstract class that defines simulacrum game master interface. + +This is an environment side simulacrum. It is responsible for providing the +observations for players and providing the outcomes for actions. It also +manages the simulated world dynamics (if there are any). +Reference: Generative Agents: Interactive Simulacra of Human Behavior +https://arxiv.org/abs/2304.03442 +""" + +import abc +from collections.abc import Sequence + + +class GameMaster(metaclass=abc.ABCMeta): + """A game master class.""" + + @property + @abc.abstractmethod + def name( + self, + ) -> str: + """Returns the name of the game.""" + raise NotImplementedError + + @abc.abstractmethod + def update_from_player( + self, + action_attempt: str, + player_name: str, + ) -> str: + """Returns the outcome of the action attempt. 
+ + Args: + action_attempt: a description of an action that the player is trying to + perform. It can succeed or fail. + player_name: the name of the player performing the action + + Returns: + the outcome of the action_attempt. + """ + raise NotImplementedError + + @abc.abstractmethod + def view_for_player( + self, + player_name: str, + ) -> str: + """Returns the view of the game state for a specific player. + + Args: + player_name: the name of the player to generate a view for + + Returns: + the view of the game state for the player. + """ + raise NotImplementedError + + @abc.abstractmethod + def run_episode(self, max_steps: int) -> Sequence[str]: + """Runs a single episode until the end. + + Args: + max_steps: the maximum number of steps + + Returns: + a list of events that happened + """ + + raise NotImplementedError diff --git a/concordia/typing/metric.py b/concordia/typing/metric.py new file mode 100644 index 00000000..4e083722 --- /dev/null +++ b/concordia/typing/metric.py @@ -0,0 +1,47 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +"""Metrics for simulations.""" + +import abc +from typing import Any + +from concordia.document import interactive_document + + +class Metric(metaclass=abc.ABCMeta): + """A class to hold logic for tracking state variables of a simulation.""" + + @abc.abstractmethod + def name( + self, + ) -> str: + """Returns the name of the measurement.""" + raise NotImplementedError + + @abc.abstractmethod + def update( + self, + observation: str, + active_player_name: str, + document: interactive_document.InteractiveDocument, + ) -> None: + """Process the observation then compute metric and store it.""" + raise NotImplementedError + + @abc.abstractmethod + def state(self) -> list[dict[str, Any]] | None: + """Return the current state of all the tracked variables.""" + raise NotImplementedError diff --git a/concordia/utils/__init__.py b/concordia/utils/__init__.py new file mode 100644 index 00000000..21637409 --- /dev/null +++ b/concordia/utils/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + diff --git a/concordia/utils/helper_functions.py b/concordia/utils/helper_functions.py new file mode 100644 index 00000000..448b8d43 --- /dev/null +++ b/concordia/utils/helper_functions.py @@ -0,0 +1,109 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Helper functions. +""" + +from collections.abc import Iterable, Sequence +import datetime + +from concordia.document import interactive_document +from concordia.language_model import language_model + + +def filter_copy_as_statement( + doc: interactive_document.InteractiveDocument, + include_tags: Iterable[str] = (), + exclude_tags: Iterable[str] = (), +) -> interactive_document.InteractiveDocument: + """Copy interactive document as an initial statement. + + Args: + doc: document to copy + include_tags: tags to include in the statement. + exclude_tags: tags to filter out from the statement. + interactive_document.DEBUG_TAG will always be added. + + Returns: + an interactive document containing a filtered copy of the input document. + """ + filtered_view = doc.view( + include_tags=include_tags, + exclude_tags={interactive_document.DEBUG_TAG, *exclude_tags}, + ) + result_doc = doc.new() + result_doc.statement(filtered_view.text()) + return result_doc + + +def extract_from_generated_comma_separated_list(x: str) -> Sequence[str]: + """Extract from a maybe badly formatted comma-separated list.""" + result = x.split(',') + return [item.strip('" ') for item in result] + + +def is_count_noun(x: str, model: language_model.LanguageModel) -> bool: + """Output True if the input is a count noun, not a mass noun. + + For a count noun you ask how *many* there are. For a mass noun you ask how + *much* there is. + + Args: + x: input string. It should be a noun. + model: a language model + Returns: + True if x is a count noun and False if x is a mass noun. 
+ """ + examples = ( + 'Question: is money a count noun? [yes/no]\n' + 'Answer: no\n' + 'Question: is coin a count noun? [yes/no]\n' + 'Answer: yes\n' + 'Question: is water a count noun? [yes/no]\n' + 'Answer: no\n' + 'Question: is apple a count noun? [yes/no]\n' + 'Answer: yes\n' + 'Question: is token a count noun? [yes/no]\n' + 'Answer: yes\n' + ) + idx, _, _ = model.sample_choice( + prompt=( + f'{examples}Question: is {x} a count noun? [yes/no]\n' + 'Answer: '), + responses=['no', 'yes'], + ) + if idx == 0: + return False + if idx == 1: + return True + + +def timedelta_to_readable_str(td: datetime.timedelta): + """Converts a datetime.timedelta object to a readable string.""" + hours = td.seconds // 3600 + minutes = (td.seconds % 3600) // 60 + seconds = td.seconds % 60 + + readable_str = [] + if hours > 0: + readable_str += [f'{hours} hour' if hours == 1 else f'{hours} hours'] + if minutes > 0: + if hours > 0: + readable_str += ' and ' + readable_str += [ + f'{minutes} minute' if minutes == 1 else f'{minutes} minutes' + ] + if seconds > 0: + if hours > 0 or minutes > 0: + readable_str += [' and '] + readable_str += [ + f'{seconds} second' if seconds == 1 else f'{seconds} seconds' + ] + + return ''.join(readable_str) diff --git a/concordia/utils/html.py b/concordia/utils/html.py new file mode 100644 index 00000000..19668541 --- /dev/null +++ b/concordia/utils/html.py @@ -0,0 +1,189 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Functions to convert python objects to HTML.""" + +import html + +HTML_HEAD = """ + + + + + + + + """ + +HTML_TAIL = """ + + + + """ + + +class HTMLWriter: + """Class to write to HTML.""" + + def __init__(self): + self.html = "" + + def write(self, text): + """Adds text to the HTML.""" + self.html += text + + def render(self): + """Returns the HTML.""" + return self.html + + +class PythonObjectToHTMLConverter: + """Class to convert python objects to HTML.""" + + def __init__(self, python_object): + self.python_object = python_object + self.html_writer = HTMLWriter() + + def convert(self): + self._convert_python_object(self.python_object) + return self.html_writer.render() + + def _convert_python_object(self, python_object): + """Converts a python object to HTML.""" + if isinstance(python_object, str): + self.html_writer.write(html.escape(python_object)) + + elif isinstance(python_object, list): + for item in python_object: + self._convert_python_object(item) + self.html_writer.write("
") + + elif isinstance(python_object, dict): + self.html_writer.write("
") + + if "date" in python_object.keys(): + self.html_writer.write("") + self._convert_python_object(python_object["date"]) + if "Summary" in python_object.keys(): + self._convert_python_object(" " + python_object["Summary"]) + self.html_writer.write("") + elif "Summary" in python_object.keys(): + self.html_writer.write("") + self._convert_python_object(" " + python_object["Summary"]) + self.html_writer.write("") + elif "Name" in python_object.keys(): + self.html_writer.write("") + self._convert_python_object(python_object["Name"]) + self.html_writer.write("") + + for key, value in python_object.items(): + if key != "date" and key != "Summary": + self.html_writer.write("
    ") + self._convert_python_object(key) + self.html_writer.write("") + self.html_writer.write("
  • ") + self._convert_python_object(value) + self.html_writer.write("
") + + self.html_writer.write("
") + else: + self.html_writer.write(str(python_object)) + + +def finalise_html(html_code): + return HTML_HEAD + html_code + HTML_TAIL + + +def combine_html_pages( + html_pages, tab_names, summary="", title="Experiment loggs" +): + """Combines multiple HTML pages into a single HTML page with tabs.""" + html_code = "" + html_code += f"""

{title}

+

{summary}

+

Click on the buttons to see the detailed loggs:

+ +
+ """ + + for tab_name in tab_names: + html_code += ( + '\n" + ) + + html_code += "
\n" + + for i, html_page in enumerate(html_pages): + html_code += ( + f'
' + html_page + "
\n" + ) + + return html_code + diff --git a/concordia/utils/measurements.py b/concordia/utils/measurements.py new file mode 100644 index 00000000..9c3f4c98 --- /dev/null +++ b/concordia/utils/measurements.py @@ -0,0 +1,94 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A module that acts like a registry of measurements for experimenter use.""" + +import threading +from typing import Any, Dict, Set +from reactivex import subject + + +class Measurements: + """A registry of measurements for experimenter use.""" + + def __init__(self): + """Initializes the Measurements object.""" + self._channels: Dict[str, subject.Subject] = {} + self._channels_lock: threading.Lock = threading.Lock() + + def _get_channel_or_create(self, channel: str) -> subject.Subject: + """Create a channel if one doesn't already exist. + + Assumes the channels lock has been acquired. Raises RuntimeError if not. + + Args: + channel: The channel name to create. + + Returns: + The channel with the given name. + + Raises: + RuntimeError: if the channels lock is not acquired. 
+ """ + if not self._channels_lock.locked(): + raise RuntimeError('Channels lock is not acquired.') + if channel not in self._channels: + # TODO(b/313610238): Maybe limit the number of new channels + self._channels[channel] = subject.ReplaySubject() + return self._channels[channel] + + def publish_datum(self, channel: str, datum: Any) -> None: + """Publishes a datum to the channel. + + Args: + channel: The channel name to push the datum into. If the channel doesn't + exist yet, it will be created. + datum: The payload to push into the channel. + """ + with self._channels_lock: + self._get_channel_or_create(channel).on_next(datum) + + def available_channels(self) -> Set[str]: + """Returns the names of all available channels.""" + with self._channels_lock: + keys: set[str] = set(self._channels.keys()) + return keys + + def get_channel(self, channel: str) -> subject.Subject: + """Returns the channel for the given name. + + Args: + channel: The channel name to get. If the channel doesn't exist yet, it + will be created. + """ + with self._channels_lock: + return self._get_channel_or_create(channel) + + def close_channel(self, channel: str) -> None: + """Closes the channel for the given name. + + Args: + channel: The channel to close. If the channel doesn't exist yet, it will + be created. + """ + with self._channels_lock: + self._get_channel_or_create(channel).on_completed() + del self._channels[channel] + + def close(self) -> None: + """Closes all channels.""" + with self._channels_lock: + for channel in self._channels.values(): + channel.on_completed() + self._channels.clear() diff --git a/concordia/utils/measurements_test.py b/concordia/utils/measurements_test.py new file mode 100644 index 00000000..ef73e0c5 --- /dev/null +++ b/concordia/utils/measurements_test.py @@ -0,0 +1,88 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the measurements library.""" + +from unittest import mock +from absl.testing import absltest +from concordia.utils import measurements + + +class MeasurementsTest(absltest.TestCase): + + def test_empty_channels(self): + msrmnts = measurements.Measurements() + self.assertEmpty(msrmnts.available_channels()) + + def test_get_channel_makes_available(self): + msrmnts = measurements.Measurements() + _ = msrmnts.get_channel("test_channel") + self.assertIn("test_channel", msrmnts.available_channels()) + + def test_get_channel_twice_is_same_object(self): + msrmnts = measurements.Measurements() + channel1 = msrmnts.get_channel("test_channel") + channel2 = msrmnts.get_channel("test_channel") + self.assertEqual(channel1, channel2) + + def test_close_calls_on_complete(self): + msrmnts = measurements.Measurements() + channel = msrmnts.get_channel("test_channel") + sentinel = mock.MagicMock() + channel.subscribe(on_completed=sentinel) + msrmnts.close_channel("test_channel") + sentinel.assert_called_once() + + def test_publish_calls_on_next_early_subscribe(self): + msrmnts = measurements.Measurements() + channel = msrmnts.get_channel("test_channel") + + sentinel = mock.MagicMock() + channel.subscribe(on_next=sentinel) + + datum = "datum" + msrmnts.publish_datum("test_channel", datum) + + msrmnts.close() + sentinel.on_next.assert_called_once_with(datum) + + def test_publish_calls_on_next_late_subscribe(self): + msrmnts = measurements.Measurements() + channel = msrmnts.get_channel("test_channel") + + datum = "datum" + msrmnts.publish_datum("test_channel", 
datum) + + sentinel = mock.MagicMock() + channel.subscribe(on_next=sentinel) + + msrmnts.close() + sentinel.on_next.assert_called_once_with(datum) + + def test_publish_calls_on_next_post_close_subscribe(self): + msrmnts = measurements.Measurements() + channel = msrmnts.get_channel("test_channel") + + datum = "datum" + msrmnts.publish_datum("test_channel", datum) + + msrmnts.close() + + sentinel = mock.MagicMock() + channel.subscribe(on_next=sentinel) + sentinel.on_next.assert_called_once_with(datum) + + +if __name__ == "__main__": + absltest.main() diff --git a/concordia/utils/plotting.py b/concordia/utils/plotting.py new file mode 100644 index 00000000..a18939d8 --- /dev/null +++ b/concordia/utils/plotting.py @@ -0,0 +1,120 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +"""Functions for plotting metrics.""" + +from typing import Collection +from concordia.utils import measurements +import matplotlib as mpl +import matplotlib.pyplot as plt +import pandas as pd + + +def plot_line_measurement_channel(measurements_obj: measurements.Measurements, + channel_name: str, + group_by: str = 'player', + xaxis: str = 'time', + yaxis: str = 'value_float') -> None: + """Plots a pie chart of a measurement channel.""" + if channel_name not in measurements_obj.available_channels(): + raise ValueError(f'Unknown channel: {channel_name}') + + channel = measurements_obj.get_channel(channel_name) + data = [] + channel.subscribe(on_next=data.append) + + plot_df_line(pd.DataFrame(data), channel_name, group_by=group_by, xaxis=xaxis, + yaxis=yaxis) + + +def plot_pie_measurement_channel(measurements_obj: measurements.Measurements, + channel_name: str, + group_by: str = 'player', + value: str = 'value_str') -> None: + """Plots a pie chart of a measurement channel.""" + if channel_name not in measurements_obj.available_channels(): + raise ValueError(f'Unknown channel: {channel_name}') + + channel = measurements_obj.get_channel(channel_name) + data = [] + channel.subscribe(on_next=data.append) + scale = set() + for datum in data: + scale |= {datum['value_str']} + + plot_df_pie(pd.DataFrame(data), scale, channel_name, group_by=group_by, + value=value) + + +def plot_df_pie(df: pd.DataFrame, + scale: Collection[str], + title: str = 'Metric', + group_by: str = 'player', + value: str = 'value_str') -> None: + """Plots a pie chart of a dataframe. + + Args: + df: The dataframe containing the data to plot. + scale: The set of possible values to plot. + title: The title of the plot. + group_by: Group data by this field, plot each one in its own figure. + value: The name of the value to aggregate for the pie chart regions. 
+ """ + cmap = mpl.colormaps['Paired'] + colours = cmap(range(len(scale))) + scale_to_colour = dict(zip(scale, colours)) + + for player, group_df in df.groupby(group_by): + plt.figure() + counts = group_df[value].value_counts() + plt.pie( + counts, + labels=counts.index, + colors=[scale_to_colour[color] for color in counts.index], + ) + plt.title(f'{title} of {player}') + + +def plot_df_line(df: pd.DataFrame, + title: str = 'Metric', + group_by: str = 'player', + xaxis: str = 'time', + yaxis: str = 'value_float') -> None: + """Plots a line chart of a dataframe. + + Args: + df: The dataframe with data to plot. + title: The title of the plot. + group_by: Group data by this field, plot each one as a line in the figure. + xaxis: The name of the column to use as the x-axis. If multiple entries have + the same value in this field, the y-axis values are averaged. + yaxis: The name of the column to use as the y-axis. The values in this + column must be numerical. + """ + ax = plt.gca() + for player, group_df in df.groupby(group_by): + group_df = group_df.groupby(xaxis).mean(numeric_only=True).reset_index() + group_df.plot(x=xaxis, y=yaxis, label=player, ax=ax) + plt.title(title) + + +def plot_metric_pie(metric): + """Plots a pie chart of the metric.""" + plot_df_pie(pd.DataFrame(metric.state()), metric.get_scale(), metric.name()) + + +def plot_metric_line(metric): + """Plots a line chart of the metric.""" + plot_df_line(pd.DataFrame(metric.state()), metric.name()) diff --git a/concordia/utils/text.py b/concordia/utils/text.py new file mode 100644 index 00000000..dea77a07 --- /dev/null +++ b/concordia/utils/text.py @@ -0,0 +1,50 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""String formatting utilities.""" + +from collections.abc import Collection +import sys +import textwrap + + +def wrap(string: str, width: int = 70) -> str: + """Returns the string wrapped to the specified width.""" + lines = string.split('\n') + wrapped_lines = (textwrap.fill(line, width=width) for line in lines) + return '\n'.join(wrapped_lines) + + +def truncate( + string: str, + *, + max_length: int = sys.maxsize, + delimiters: Collection[str] = (), +) -> str: + """Truncates a string. + + Args: + string: string to truncate + max_length: maximum length of the string. + delimiters: delimiters that must not be present in the truncated string. + + Returns: + The longest prefix of string that does not exceed max_length and does not + contain any delimiter. + """ + truncated = string[:max_length] + for delimiter in delimiters: + truncated = truncated.split(delimiter, 1)[0] + return truncated diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 00000000..21637409 --- /dev/null +++ b/examples/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + diff --git a/examples/requirements.txt b/examples/requirements.txt new file mode 100644 index 00000000..a7d88420 --- /dev/null +++ b/examples/requirements.txt @@ -0,0 +1,2 @@ +dm-concordia +termcolor diff --git a/examples/village/__init__.py b/examples/village/__init__.py new file mode 100644 index 00000000..21637409 --- /dev/null +++ b/examples/village/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + diff --git a/examples/village/components/__init__.py b/examples/village/components/__init__.py new file mode 100644 index 00000000..21637409 --- /dev/null +++ b/examples/village/components/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + diff --git a/examples/village/components/elections.py b/examples/village/components/elections.py new file mode 100644 index 00000000..d746425d --- /dev/null +++ b/examples/village/components/elections.py @@ -0,0 +1,212 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Construct and an externality that implements elections within a game master.""" + +from collections.abc import Callable, Sequence +import datetime + +from concordia.associative_memory import associative_memory +from concordia.document import interactive_document +from concordia.language_model import language_model +from concordia.typing import agent +from concordia.typing import component +from concordia.utils import measurements as measurements_lib +import termcolor + +DEFAULT_CHANNEL_NAME = 'election' + + +class Elections(component.Component): + """Tracks elections.""" + + def __init__( + self, + model: language_model.LanguageModel, + memory: associative_memory.AssociativeMemory, + voters: Sequence[agent.GenerativeAgent], + candidates: Sequence[str], + clock_now: Callable[[], datetime.datetime], + verbose: bool = False, + measurements: measurements_lib.Measurements | None = None, + channel: str = DEFAULT_CHANNEL_NAME, + ): + """Initializes the election tracker. + + Args: + model: The language model to use. + memory: The memory to use. + voters: The agent voters. + candidates: The candidates in the election. 
+ clock_now: Function to call to get current time. Used for logging. + verbose: Whether to print verbose messages. + measurements: Optional object to publish data from the elections. + channel: Channel in measurements to publish to. + """ + self._model = model + self._memory = memory + self._voters = voters + self._candidates = candidates + self._clock_now = clock_now + self._verbose = verbose + self._measurements = measurements + self._channel = channel + + self._voter_names = [voter.name for voter in self._voters] + self._vote_count = {candidate: 0 for candidate in self._candidates} + self._citizens_who_already_voted = set() + + self._voter_by_name = {voter.name: voter for voter in self._voters} + self._state = 'Polls are not open yet.' + self._partial_states = None + self._history = [] + + self._polls_open = False + self._winner_declared = False + self._timestep = 0 + + def get_last_log(self): + if self._history: + return self._history[-1].copy() + + def get_history(self): + return self._history.copy() + + def open_polls(self) -> None: + self._polls_open = True + self._state = 'Polls are open, voting in progress.' + + def declare_winner(self) -> None: + if not self._winner_declared: + self._winner_declared = True + self._polls_open = False + winner = max(self._vote_count, key=self._vote_count.get) + self._state = f'Polls are closed. {winner} won the election.' 
+ if self._verbose: + print(termcolor.colored('\n' + self._state, 'red'), end='') + + self._memory.add(self._state, tags=['election tracker']) + + def name(self) -> str: + return 'State of election' + + def state(self) -> str: + return self._state + + def update(self) -> None: + pass + + def get_vote_count(self) -> dict[str, int]: + return self._vote_count + + def partial_state( + self, + player_name: str, + ) -> str: + """Return a player-specific view of the construct's state.""" + return self._state + + def update_after_event( + self, + event_statement: str, + ) -> None: + if not self._polls_open: + update_log = { + 'date': self._clock_now(), + 'Summary': 'Polls are not open.', + 'Vote count': str(self._vote_count), + } + self._history.append(update_log) + return + + chain_of_thought = interactive_document.InteractiveDocument(self._model) + chain_of_thought.statement(event_statement) + chain_of_thought.statement(f'List of citizens: {self._voter_names}') + active_voter_id = chain_of_thought.multiple_choice_question( + question='In the above transcript, which citizen took an action?', + answers=self._voter_names, + ) + vote = None + active_voter = self._voter_names[active_voter_id] + if active_voter not in self._citizens_who_already_voted: + did_vote = chain_of_thought.yes_no_question( + question=f'Did {active_voter} vote in the above transcript?' + ) + if did_vote: + question = ( + f'Current activity: {event_statement}.\nGiven the above, who whould' + f' {active_voter} vote for?' 
+ ) + action_spec = agent.ActionSpec( + call_to_action=question, + output_type='CHOICE', + options=self._candidates, + tag='vote', + ) + vote = self._voter_by_name[active_voter].act(action_spec) + + self._vote_count[vote] += 1 + self._citizens_who_already_voted.add(active_voter) + self._memory.add( + f'{active_voter} voted for {vote}', tags=['election tracker'] + ) + if self._verbose: + print( + termcolor.colored( + f'\n {active_voter} voted for {vote}\n', 'magenta' + ) + ) + else: + if self._verbose: + print( + termcolor.colored( + f'\n {active_voter} did not vote in the transcript.\n', + 'magenta', + ) + ) + else: + chain_of_thought.statement(f'{active_voter} already voted.') + + update_log = { + 'date': self._clock_now(), + 'Summary': str(self._vote_count), + 'Vote count': str(self._vote_count), + 'Chain of thought': { + 'Summary': 'Election tracker chain of thought', + 'Chain': chain_of_thought.view().text().splitlines(), + }, + } + self._history.append(update_log) + + if self._verbose: + print( + termcolor.colored( + f'{self._vote_count}\n' + chain_of_thought.view().text(), + 'magenta', + ) + ) + + if self._measurements is not None and vote is not None: + answer = self._vote_count[vote] + answer_str = str(answer) + datum = { + 'time_str': self._clock_now().strftime('%H:%M:%S'), + 'timestep': self._timestep, + 'value_float': answer, + 'value_str': answer_str, + 'player': vote, + } + self._measurements.publish_datum(channel=self._channel, datum=datum) + self._timestep += 1 diff --git a/examples/village/day_in_riverbend.ipynb b/examples/village/day_in_riverbend.ipynb new file mode 100644 index 00000000..c4b8d861 --- /dev/null +++ b/examples/village/day_in_riverbend.ipynb @@ -0,0 +1,770 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "8jAYDAvRLTQY" + }, + "source": [ + "```\n", + "Copyright 2023 DeepMind Technologies Limited.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "you may not use this file 
except in compliance with the License.\n", + "You may obtain a copy of the License at\n", + "\n", + " https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "Unless required by applicable law or agreed to in writing, software\n", + "distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "See the License for the specific language governing permissions and\n", + "limitations under the License.\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zWgEkOAO9OVz" + }, + "source": [ + "# Day in Riverbend Example\n", + "\n", + "An illustrative social simulation with 5 players which simulates a normal day in an imaginary town caller Riverbend. Each player has their own configurable backstory. The agents are configured to re-implement the architecure in Park et al. (2023) - they have reflection, plan, and identity components; their associative memory uses importance function. This is _not_ an exact re-implementation.\n", + "\n", + "Park, J.S., O'Brien, J.C., Cai, C.J., Morris, M.R., Liang, P. and Bernstein, M.S., 2023. Generative agents: Interactive simulacra of human behavior. arXiv preprint arXiv:2304.03442." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "J2TwJrZ08wXz" + }, + "source": [ + "## Init and import" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-qLG5ExLqpWa" + }, + "outputs": [], + "source": [ + "# @title Imports\n", + "\n", + "import concurrent.futures\n", + "import datetime\n", + "import random\n", + "\n", + "from google.colab import widgets\n", + "from IPython import display\n", + "\n", + "from concordia.agents import basic_agent\n", + "from concordia.agents import components\n", + "from concordia.associative_memory import associative_memory\n", + "from concordia.associative_memory import blank_memories\n", + "from concordia.associative_memory import embedder_st5\n", + "from concordia.associative_memory import formative_memories\n", + "from concordia.associative_memory import importance_function\n", + "from concordia.clocks import game_clock\n", + "from concordia.environment import components as gm_components\n", + "from concordia.environment import game_master\n", + "from concordia.environment.metrics import goal_achievement\n", + "from concordia.language_model import sax_model\n", + "from concordia.utils import html as html_lib\n", + "from concordia.utils import plotting" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "I3OtW8flCJSC" + }, + "outputs": [], + "source": [ + "# @title Setup sentence encoder\n", + "embedder = embedder_st5.EmbedderST5()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cugwvFIKv5AS" + }, + "outputs": [], + "source": [ + "# @title SAX Language Model\n", + "\n", + "# Add path to your SAX server here:\n", + "SAX_PATH = '' # @param {type:\"string\"}\n", + "DEFAULT_MAX_TOKENS = 300 # @param {type: 'integer'}\n", + "DEFAULT_TIMEOUT_SECONDS = 60 # @param {type: 'number'}\n", + "\n", + "model = sax_model.SAXLanguageModel(SAX_PATH)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": 
"z9HYjZgyakc_" + }, + "source": [ + "## Configuring the generic knowledge of players and GM." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cVfExQ0suX5j" + }, + "outputs": [], + "source": [ + "#@title Make importance models\n", + "\n", + "importance_model = importance_function.AgentImportanceModel(model)\n", + "importance_model_gm = importance_function.ConstantImportanceModel()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TeVYseoD2WYa" + }, + "outputs": [], + "source": [ + "#@title Make the clock\n", + "SETUP_TIME = datetime.datetime(hour=8, year=2024, month=9, day=1)\n", + "\n", + "START_TIME = datetime.datetime(hour=9, year=2024, month=10, day=1)\n", + "clock = game_clock.MultiIntervalClock(\n", + " start=SETUP_TIME,\n", + " step_sizes=[datetime.timedelta(hours=1), datetime.timedelta(seconds=10)])\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "b8vWoQ6by51N" + }, + "outputs": [], + "source": [ + "# @title Generic memories are memories that all players and GM share.\n", + "\n", + "shared_memories = [\n", + " 'There is a hamlet named Riverbend.',\n", + " 'Riverbend is an idyllic rural town.',\n", + " 'The river Solripple runs through the village of Riverbend.',\n", + " 'The Solripple is a mighty river.',\n", + " 'Riverbend has a temperate climate.',\n", + " 'Riverbend has a main street.',\n", + " 'There is a guitar store on Main street Riverbend.',\n", + " 'There is a grocery store on Main street Riverbend.',\n", + " 'There is a school on Main street Riverbend.',\n", + " 'There is a library on Main street Riverbend.',\n", + " 'Riverbend has only one pub.',\n", + " 'There is a pub on Main street Riverbend called The Sundrop Saloon.',\n", + " 'Town hall meetings often take place at The Sundrop Saloon.',\n", + " 'Riverbend does not have a park',\n", + " 'The main crop grown on the farms near Riverbend is alfalfa.',\n", + " 'Farms 
near Riverbend depend on water from the Solripple river.',\n", + "]\n", + "\n", + "# The generic context will be used for the NPC context. It reflects general\n", + "# knowledge and is possessed by all characters.\n", + "shared_context = model.sample_text(\n", + " 'Summarize the following passage in a concise and insightful fashion:\\n'\n", + " + '\\n'.join(shared_memories)\n", + " + '\\n'\n", + " + 'Summary:'\n", + ")\n", + "print(shared_context)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YBCXUQ8sayzj" + }, + "source": [ + "## Functions to build the players" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "T41TQHB7vacw" + }, + "outputs": [], + "source": [ + "# @title setup formative memory factories\n", + "blank_memory_factory = blank_memories.MemoryFactory(\n", + " model=model,\n", + " embedder=embedder,\n", + " importance=importance_model.importance,\n", + " clock_now=clock.now,\n", + ")\n", + "formative_memory_factory = formative_memories.FormativeMemoryFactory(\n", + " model=model,\n", + " shared_memories=shared_memories,\n", + " blank_memory_factory_call=blank_memory_factory.make_blank_memory,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "As465DbcsAwZ" + }, + "outputs": [], + "source": [ + "def build_agent(agent_config):\n", + "\n", + " mem = formative_memory_factory.make_memories(agent_config)\n", + "\n", + " # Build the player.\n", + "\n", + " time = components.report_state.ReportState(\n", + " name='current_time',\n", + " get_state=clock.current_time_interval_str)\n", + "\n", + " identity = components.identity.SimIdentity(model, mem, agent_config.name)\n", + " goal_component = components.constant.ConstantConstruct(state=agent_config.goal)\n", + " reflection = components.reflection.Reflection(\n", + " model=model,\n", + " memory=mem,\n", + " agent_name=agent_config.name,\n", + " importance_threshold=15.0,\n", + " verbose=False,\n", + " )\n", + 
" plan = components.plan.SimPlan(\n", + " model,\n", + " mem,\n", + " agent_config.name,\n", + " components=[identity, time],\n", + " goal=goal_component,\n", + " verbose=False,\n", + " )\n", + " current_obs = components.observation.Observation(agent_config.name, memory=mem)\n", + " summary_obs = components.observation.ObservationSummary(\n", + " model=model,\n", + " agent_name=agent_config.name,\n", + " components=[identity],\n", + " )\n", + " agent = basic_agent.BasicAgent(\n", + " model,\n", + " mem,\n", + " agent_name=agent_config.name,\n", + " clock=clock,\n", + " verbose=True,\n", + " components=[identity, plan, reflection, time, summary_obs, current_obs],\n", + " )\n", + " agent.update()\n", + "\n", + " return agent" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qt8CK2mMbD7q" + }, + "source": [ + "## Configure and build the players" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "QhAkMajsunp4" + }, + "outputs": [], + "source": [ + "def make_random_big_five()-\u003estr:\n", + " return str({\n", + " 'extraversion': random.randint(1, 10),\n", + " 'neuroticism': random.randint(1, 10),\n", + " 'openness': random.randint(1, 10),\n", + " 'conscientiousness': random.randint(1, 10),\n", + " 'agreeableness': random.randint(1, 10),\n", + " })\n", + "\n", + "scenario_premise = [\n", + "\n", + " (\n", + " 'Alice, Bob, Charlie and Dorothy are at the Sundrop Saloon. 
There '\n", + " + 'is a snow storm and they have to wait it out inside.'\n", + " ),\n", + "]\n", + "player_configs = [\n", + " formative_memories.AgentConfig(\n", + " name='Alice',\n", + " gender='female',\n", + " goal='Organise a street party in Riverbend.',\n", + " context=shared_context+'Alice is very socially active and knows everyone in town',\n", + " traits = make_random_big_five()\n", + " ),\n", + " formative_memories.AgentConfig(\n", + " name='Bob',\n", + " gender='male',\n", + " goal='Start a chess club in Riverbend.',\n", + " context=shared_context + 'Bob is a chess enthusiast',\n", + " traits = make_random_big_five()\n", + " ),\n", + " formative_memories.AgentConfig(\n", + " name='Charlie',\n", + " gender='male',\n", + " goal='Organise an ale festival at the Sundrop Saloon.',\n", + " context=shared_context + 'Charlie works at the Sundrop Saloon and loves real ales',\n", + " traits = make_random_big_five()\n", + " ),\n", + " formative_memories.AgentConfig(\n", + " name='Dorothy',\n", + " gender='female',\n", + " goal=(\n", + " 'Take students on a tour of Riverbend'\n", + " ' it is funny.'\n", + " ),\n", + " context=shared_context + 'Dorothy is a teacher at school in Riverbend',\n", + " traits = make_random_big_five()\n", + " ),\n", + " formative_memories.AgentConfig(\n", + " name='Ellen',\n", + " gender='female',\n", + " goal=(\n", + " 'Write a paper on the history of Riverbend.'\n", + " ),\n", + " context=shared_context + 'Ellen is a librarian in the library in Riverbend',\n", + " traits = make_random_big_five()\n", + " ),\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5RU3ZV4oIknW" + }, + "outputs": [], + "source": [ + "NUM_PLAYERS = 3\n", + "\n", + "player_configs = player_configs[:NUM_PLAYERS]\n", + "player_goals = {player_config.name: player_config.goal for player_config in player_configs}\n", + "players = []\n", + "\n", + "with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_PLAYERS) as 
pool:\n", + " for agent in pool.map(build_agent, player_configs[:NUM_PLAYERS]):\n", + " players.append(agent)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2vt8ggYUrW8M" + }, + "source": [ + "## Build GM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "siwglxrc6z2j" + }, + "outputs": [], + "source": [ + "game_master_instructions = (\n", + " 'This is a social science experiment. It is structured as a '\n", + " 'tabletop roleplaying game (like dungeons and dragons). You are the '\n", + " 'game master. You will describe the current situation to the '\n", + " 'participants in the experiment and then on the basis of what you '\n", + " 'tell them they will suggest actions for the character they control. '\n", + " 'Aside from you, each other participant controls just one character. '\n", + " 'You are the game master so you may control any non-player '\n", + " 'character. You will track the state of the world and keep it '\n", + " 'consistent as time passes in the simulation and the participants '\n", + " 'take actions and change things in their world. Remember that this '\n", + " 'is a serious social science experiment. It is not just a game. It '\n", + " 'need not be fun for the participants. 
Always use third-person '\n", + " 'limited perspective, even when speaking directly to the participants.'\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3W65kHOKQwrv" + }, + "outputs": [], + "source": [ + "game_master_memory = associative_memory.AssociativeMemory(\n", + " embedder, importance_model_gm.importance, clock=clock.now)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bGNY_D7FID4I" + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-cxivChc633z" + }, + "outputs": [], + "source": [ + "# @title Create components and externalities\n", + "citizen_names = [player.name for player in players]\n", + "player_names = [player.name for player in players]\n", + "\n", + "instructions_construct = components.constant.ConstantConstruct(\n", + " game_master_instructions, 'Instructions'\n", + ")\n", + "facts_on_village = components.constant.ConstantConstruct(\n", + " ' '.join(shared_memories), 'General knowledge of Riverbend'\n", + ")\n", + "player_status = gm_components.player_status.PlayerStatus(\n", + " clock.now, model, game_master_memory, player_names\n", + ")\n", + "\n", + "relevant_events = gm_components.relevant_events.RelevantEvents(\n", + " clock.now, model, game_master_memory\n", + ")\n", + "time_display = gm_components.time_display.TimeDisplay(clock)\n", + "\n", + "\n", + "convo_externality = gm_components.conversation.Conversation(\n", + " players,\n", + " model,\n", + " clock=clock,\n", + " memory=game_master_memory,\n", + " burner_memory_factory=blank_memory_factory,\n", + " components=[player_status],\n", + " cap_nonplayer_characters=2,\n", + " game_master_instructions=game_master_instructions,\n", + " shared_context=shared_context,\n", + " verbose=False,\n", + ")\n", + "\n", + "direct_effect_externality = gm_components.direct_effect.DirectEffect(\n", + " players,\n", + " model=model,\n", + " 
memory=game_master_memory,\n", + " clock_now=clock.now,\n", + " verbose=False,\n", + " components=[player_status],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5SpNVmlh6_hp" + }, + "outputs": [], + "source": [ + "# @title Metrics\n", + "goal_metric = goal_achievement.GoalAchievementMetric(model, player_goals, clock, 'Goal achievement', verbose=False)\n", + "\n", + "metrics = [goal_metric]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "d_R2BVNOsAwa" + }, + "outputs": [], + "source": [ + "# @title Create the game master object\n", + "env = game_master.GameMaster(\n", + " model=model,\n", + " memory=game_master_memory,\n", + " clock=clock,\n", + " players=players,\n", + " components=[\n", + " instructions_construct,\n", + " facts_on_village,\n", + " player_status,\n", + " convo_externality,\n", + " direct_effect_externality,\n", + " relevant_events,\n", + " time_display,\n", + " ],\n", + " measurements=metrics,\n", + " randomise_initiative=True,\n", + " player_observes_event=False,\n", + " verbose=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d2u0bQ1MSCGd" + }, + "source": [ + "## The RUN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "hdTRDaxEZZnN" + }, + "outputs": [], + "source": [ + "clock.set(START_TIME)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9IggLF1aH_hF" + }, + "outputs": [], + "source": [ + "#@title Initial observations and player location\n", + "for player in players:\n", + " player.observe(\n", + " f'{player.name} is at home, they have just woken up.'\n", + " )\n", + " game_master_memory.add(f'{player.name} is at their private home.')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2Bt87stq76gF" + }, + "outputs": [], + "source": [ + "# @title Expect about 2-3 minutes per step.\n", + "episode_length 
= 12 # @param {type: 'integer'}\n", + "for _ in range(episode_length):\n", + " env.step()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DnwvpvQ4bnFs" + }, + "source": [ + "## Summary and analysis of the episode" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "5U5FDXvs4HSr" + }, + "outputs": [], + "source": [ + "# @title Metrics plotting\n", + "tb = widgets.TabBar([metric.name() for metric in metrics])\n", + "\n", + "for metric in metrics:\n", + " with tb.output_to(metric.name()):\n", + " plotting.plot_metric_line(metric)\n", + " plotting.plot_metric_pie(metric)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "O4jp0xGXvOAJ" + }, + "outputs": [], + "source": [ + "# @title Summarize the entire story.\n", + "all_gm_memories = env._memory.retrieve_recent(k=10000, add_time=True)\n", + "\n", + "detailed_story = '\\n'.join(all_gm_memories)\n", + "print('len(detailed_story): ', len(detailed_story))\n", + "# print(detailed_story)\n", + "\n", + "episode_summary = model.sample_text(\n", + " f'Sequence of events:\\n{detailed_story}'+\n", + " '\\nNarratively summarize the above temporally ordered ' +\n", + " 'sequence of events. Write it as a news report. 
Summary:\\n',\n", + " max_characters=8000, max_tokens=8000, terminators=())\n", + "print(episode_summary)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ALG987t-6j-V" + }, + "outputs": [], + "source": [ + "# @title Summarise the perspective of each player\n", + "player_logs = []\n", + "player_log_names = []\n", + "for player in players:\n", + " name = player.name\n", + " detailed_story = '\\n'.join(player._memory.retrieve_recent(k=1000, add_time=True))\n", + " summary = ''\n", + " summary = model.sample_text(\n", + " f'Sequence of events that happened to {name}:\\n{detailed_story}'\n", + " '\\nWrite a short story that summarises these events.\\n'\n", + " ,\n", + " max_characters=8000, max_tokens=8000, terminators=())\n", + "\n", + " all_player_mem = player._memory.retrieve_recent(k=1000, add_time=True)\n", + " all_player_mem = ['Summary:', summary, 'Memories:'] + all_player_mem\n", + " player_html = html_lib.PythonObjectToHTMLConverter(all_player_mem).convert()\n", + " player_logs.append(player_html)\n", + " player_log_names.append(f'{name}')\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UmPOvjVxddye" + }, + "source": [ + "#Build and display HTML log of the experiment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JyEoGgI05xI0" + }, + "outputs": [], + "source": [ + "history_sources = [env, direct_effect_externality, convo_externality]\n", + "histories_html = [html_lib.PythonObjectToHTMLConverter(history.get_history()).convert() for history in history_sources]\n", + "histories_names = [history.name() for history in history_sources]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XNJuo4Dwt5Ui" + }, + "outputs": [], + "source": [ + "gm_mem_html = html_lib.PythonObjectToHTMLConverter(all_gm_memories).convert()\n", + "\n", + "tabbed_html = html_lib.combine_html_pages(\n", + " histories_html + [gm_mem_html] + player_logs,\n", 
+ " histories_names + ['GM'] + player_log_names,\n", + " summary=episode_summary,\n", + " title='Riverbend elections experiment',\n", + ")\n", + "\n", + "tabbed_html = html_lib.finalise_html(tabbed_html)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pumxvmrzANOq" + }, + "outputs": [], + "source": [ + "display.HTML(tabbed_html)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HX-M9Im_dneG" + }, + "source": [ + "#Interact with a specific player" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ESJ1l7_Kt5Uj" + }, + "outputs": [], + "source": [ + "sim_to_interact = 'Alice' # @param ['Alice', 'Bob','Charlie', 'Dorothy', 'Ellen'] {type:\"string\"}\n", + "user_identity = 'a close friend' # @param {type:\"string\"}\n", + "interaction_premise = f'{sim_to_interact} is talking to {user_identity}\\n' # @param {type:\"string\"}\n", + "\n", + "player_names = [player.name for player in players]\n", + "player_by_name = {player.name: player for player in players}\n", + "selected_player = player_by_name[sim_to_interact]\n", + "interrogation = interaction_premise" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5Q1cYflLt5Uj" + }, + "outputs": [], + "source": [ + "utterence_from_user = 'Did you win the elections?' 
# @param {type:\"string\"}\n", + "\n", + "interrogation += f'{user_identity}: {utterence_from_user}'\n", + "player_says = selected_player.say(interrogation)\n", + "interrogation += f'\\n{sim_to_interact}: {player_says}\\n'\n", + "print(interrogation)" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "HX-M9Im_dneG" + ], + "last_runtime": { + "build_target": "", + "kind": "private" + }, + "private_outputs": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/examples/village/metrics/__init__.py b/examples/village/metrics/__init__.py new file mode 100644 index 00000000..21637409 --- /dev/null +++ b/examples/village/metrics/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + diff --git a/examples/village/metrics/elections.py b/examples/village/metrics/elections.py new file mode 100644 index 00000000..a23f0904 --- /dev/null +++ b/examples/village/metrics/elections.py @@ -0,0 +1,86 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Metric of election outcome.""" + +from collections.abc import Sequence +from typing import Any + +from concordia.document import interactive_document +from concordia.examples.village.components import elections +from concordia.typing import clock as game_clock +from concordia.typing import metric + + +class Elections(metric.Metric): + """A metric to track votes in an election.""" + + def __init__( + self, + clock: game_clock.GameClock, + election_externality: elections.Elections, + name: str = 'Vote count', + verbose: bool = False, + writer=None, + ): + self._name = name + self._state = [] + self._clock = clock + self._verbose = verbose + self._writer = writer + self._election_tracker = election_externality + self._vote_count = election_externality.get_vote_count() + + def name( + self, + ) -> str: + """Returns the name of the measurement.""" + return self._name + + def update( + self, + observation: str, + acting_player_name: str, + doc: interactive_document.InteractiveDocument, + ) -> None: + self._vote_count = self._election_tracker.get_vote_count() + if acting_player_name not in self._vote_count.keys(): + return + answer = self._vote_count[acting_player_name] + answer_str = str(answer) + datum = { + 'time_str': self._clock.now().strftime('%H:%M:%S'), + 'clock_step': self._clock.get_step(), + 'step_metric': len(self._state), + 'value_float': answer, + 'value_str': answer_str, + 'player': acting_player_name, + } + if self._writer is not None: + self._writer.write(datum) + self._writer.flush() + + datum['time'] = self._clock.now() + 
self._state.append(datum) + + if self._verbose: + print(f'{self._name} of {acting_player_name}: {answer_str}') + + def state(self) -> list[dict[str, Any]]: + """Return the current state of all the tracked variables.""" + return self._state.copy() + + def get_scale(self) -> Sequence[str]: + return list(set([str(i) for i in self._vote_count.values()])) diff --git a/examples/village/riverbend_elections.ipynb b/examples/village/riverbend_elections.ipynb new file mode 100644 index 00000000..35b9803e --- /dev/null +++ b/examples/village/riverbend_elections.ipynb @@ -0,0 +1,907 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "sec0h5LfLh-Z" + }, + "source": [ + "```\n", + "Copyright 2023 DeepMind Technologies Limited.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "you may not use this file except in compliance with the License.\n", + "You may obtain a copy of the License at\n", + "\n", + " https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "Unless required by applicable law or agreed to in writing, software\n", + "distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "See the License for the specific language governing permissions and\n", + "limitations under the License.\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zWgEkOAO9OVz" + }, + "source": [ + "# Riverbend Election Example\n", + "\n", + "An illustrative social simulation with 5 players which simulates the day of mayoral elections in an imaginary town caller Riverbend. First two players, Alice and Bob, are running for the mayor. The third player, Charlie, is trying to ruin Alices' reputation with disinformation. The last two players have no specific agenda, apart from voting in the election." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "J2TwJrZ08wXz" + }, + "source": [ + "## Init and import" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "-qLG5ExLqpWa" + }, + "outputs": [], + "source": [ + "# @title Imports\n", + "\n", + "import collections\n", + "import concurrent.futures\n", + "import datetime\n", + "import random\n", + "\n", + "from google.colab import widgets\n", + "from IPython import display\n", + "\n", + "from concordia.agents import basic_agent\n", + "from concordia.agents import components\n", + "from concordia.associative_memory import associative_memory\n", + "from concordia.associative_memory import blank_memories\n", + "from concordia.associative_memory import embedder_st5\n", + "from concordia.associative_memory import formative_memories\n", + "from concordia.associative_memory import importance_function\n", + "from concordia.clocks import game_clock\n", + "from concordia.environment import components as gm_components\n", + "from concordia.environment import game_master\n", + "from concordia.language_model import sax_model\n", + "from concordia.metrics import goal_achievement\n", + "from concordia.metrics import common_sense_morality\n", + "from concordia.metrics import opinion_of_others\n", + "from concordia.utils import html as html_lib\n", + "from concordia.utils import measurements as measurements_lib\n", + "from concordia.utils import plotting\n", + "\n", + "from concordia.examples.village.components import elections\n", + "from concordia.examples.village.metrics import elections as elections_metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "I3OtW8flCJSC" + }, + "outputs": [], + "source": [ + "# @title Setup sentence encoder\n", + "embedder = embedder_st5.EmbedderST5()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": 
"cugwvFIKv5AS" + }, + "outputs": [], + "source": [ + "# @title SAX Language Model\n", + "\n", + "# Add path to your SAX server here:\n", + "SAX_PATH = '' # @param {type:\"string\"}\n", + "DEFAULT_MAX_TOKENS = 300 # @param {type: 'integer'}\n", + "DEFAULT_TIMEOUT_SECONDS = 60 # @param {type: 'number'}\n", + "\n", + "model = sax_model.SAXLanguageModel(SAX_PATH)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z9HYjZgyakc_" + }, + "source": [ + "## Configuring the generic knowledge of players and GM." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "b8vWoQ6by51N" + }, + "outputs": [], + "source": [ + "# @title Generic memories are memories that all players and GM share.\n", + "\n", + "shared_memories = [\n", + " 'There is a hamlet named Riverbend.',\n", + " 'Riverbend is an idyllic rural town.',\n", + " 'The river Solripple runs through the village of Riverbend.',\n", + " 'The Solripple is a mighty river.',\n", + " 'Riverbend has a temperate climate.',\n", + " 'Riverbend has a main street.',\n", + " 'There is a guitar store on Main street Riverbend.',\n", + " 'There is a grocery store on Main street Riverbend.',\n", + " 'There is a school on Main street Riverbend.',\n", + " 'There is a library on Main street Riverbend.',\n", + " 'Riverbend has only one pub.',\n", + " 'There is a pub on Main street Riverbend called The Sundrop Saloon.',\n", + " 'Town hall meetings often take place at The Sundrop Saloon.',\n", + " 'Riverbend does not have a park',\n", + " 'The main crop grown on the farms near Riverbend is alfalfa.',\n", + " 'Farms near Riverbend depend on water from the Solripple river.',\n", + " (\n", + " 'The local newspaper recently reported that someone has been dumping '\n", + " + 'dangerous industrial chemicals in the Solripple river.'\n", + " ),\n", + " 'All named characters are citizens. ',\n", + " # 'All citizens are automatically candidates in all elections. 
',\n", + " 'There is no need to register in advance to be on the ballot.',\n", + "]\n", + "\n", + "# The generic context will be used for the NPC context. It reflects general\n", + "# knowledge and is possessed by all characters.\n", + "shared_context = model.sample_text(\n", + " 'Summarize the following passage in a concise and insightful fashion:\\n'\n", + " + '\\n'.join(shared_memories)\n", + " + '\\n'\n", + " + 'Summary:'\n", + ")\n", + "print(shared_context)\n", + "importance_model = importance_function.ConstantImportanceModel()\n", + "importance_model_gm = importance_function.ConstantImportanceModel()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "TeVYseoD2WYa" + }, + "outputs": [], + "source": [ + "#@title Make the clock\n", + "SETUP_TIME = datetime.datetime(hour=8, year=2024, month=9, day=1)\n", + "\n", + "START_TIME = datetime.datetime(hour=9, year=2024, month=10, day=1)\n", + "clock = game_clock.MultiIntervalClock(\n", + " start=SETUP_TIME,\n", + " step_sizes=[datetime.timedelta(hours=1), datetime.timedelta(seconds=10)])\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YBCXUQ8sayzj" + }, + "source": [ + "## Functions to build the players" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "fNpnn9QY4IN6" + }, + "outputs": [], + "source": [ + "blank_memory_factory = blank_memories.MemoryFactory(\n", + " model=model,\n", + " embedder=embedder,\n", + " importance=importance_model.importance,\n", + " clock_now=clock.now,\n", + ")\n", + "\n", + "formative_memory_factory = formative_memories.FormativeMemoryFactory(\n", + " model=model,\n", + " shared_memories=shared_memories,\n", + " blank_memory_factory_call=blank_memory_factory.make_blank_memory,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "As465DbcsAwZ" + }, + "outputs": [], + "source": [ + "def build_a_citizen(agent_config,\n", + " 
player_names: list[str],\n", + " measurements: measurements_lib.Measurements | None = None):\n", + "\n", + " mem = formative_memory_factory.make_memories(agent_config)\n", + "\n", + " # Build the player.\n", + "\n", + " time = components.report_state.ReportState(\n", + " name='Current time',\n", + " get_state=clock.current_time_interval_str)\n", + "\n", + " somatic_state = components.somatic_state.SomaticState(\n", + " model, mem, agent_config.name, clock\n", + " )\n", + " identity = components.identity.SimIdentity(model, mem, agent_config.name)\n", + " goal_component = components.constant.ConstantConstruct(state=agent_config.goal)\n", + " plan = components.plan.SimPlan(\n", + " model,\n", + " mem,\n", + " agent_config.name,\n", + " components=[identity],\n", + " goal=goal_component,\n", + " verbose=False,\n", + " )\n", + " current_obs = components.observation.Observation(agent_config.name, mem)\n", + " summary_obs = components.observation.ObservationSummary(\n", + " model=model,\n", + " agent_name=agent_config.name,\n", + " components=[identity],\n", + " )\n", + " goal_metric = goal_achievement.GoalAchievementMetric(\n", + " model=model,\n", + " player_name=agent_config.name,\n", + " player_goal=agent_config.goal,\n", + " clock=clock,\n", + " name='Goal Achievement',\n", + " measurements=measurements,\n", + " channel='goal_achievement',\n", + " verbose=False,\n", + " )\n", + " morality_metric = common_sense_morality.CommonSenseMoralityMetric(\n", + " model=model,\n", + " player_name=agent_config.name,\n", + " clock=clock,\n", + " name='Morality',\n", + " verbose=False,\n", + " measurements=measurements,\n", + " channel='common_sense_morality',\n", + " )\n", + " agent = basic_agent.BasicAgent(\n", + " model,\n", + " mem,\n", + " agent_name=agent_config.name,\n", + " clock=clock,\n", + " verbose=True,\n", + " components=[identity, plan, somatic_state, summary_obs, current_obs,time,\n", + " goal_metric, morality_metric],\n", + " )\n", + " reputation_metric = 
opinion_of_others.OpinionOfOthersMetric(\n", + " model=model,\n", + " player_name=agent_config.name,\n", + " player_names=player_names,\n", + " context_fn=agent.state,\n", + " clock=clock,\n", + " name='Opinion',\n", + " verbose=False,\n", + " measurements=measurements,\n", + " channel='opinion_of_others',\n", + " question='What is {opining_player}\\'s opinion of {of_player}?',\n", + " )\n", + " agent.add_component(reputation_metric)\n", + "\n", + " return agent" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "hgM4a234xqaJ" + }, + "outputs": [], + "source": [ + "def context_from_role(player_name, role, victim='Alice'):\n", + " context = f'{shared_context}\\n{player_name} grew up in Riverbend.'\n", + " # Divide players into two classes, half are biased toward the status quo and\n", + " # the other half want change.\n", + " if 'conservative' in role:\n", + " context = (\n", + " f'{context} {player_name} feels strongly that Riverbend is '\n", + " + 'already a great place to live and does not need to change.'\n", + " )\n", + " if 'progressive' in role:\n", + " context = (\n", + " f'{context} {player_name} feels strongly that Riverbend and '\n", + " + 'its local government are in dire need of reform.'\n", + " )\n", + " # The first two players are mayoral candidates.\n", + " if 'candidate' in role:\n", + " context = (\n", + " f'{context} Because of this, {player_name} plans to run for '\n", + " + 'mayor of Riverbend.'\n", + " )\n", + " if 'active_voter' in role:\n", + " context = (\n", + " f'{context} {player_name} does not plan to run for mayor of Riverbend,'\n", + " ' but would definitely vote in the election.'\n", + " )\n", + "\n", + " # The third player has been hired to ruin the first player's reputation.\n", + " if 'corrupt' in role:\n", + " context = (\n", + " f'{context}\\n'\n", + " + f\"{player_name} has been hired to ruin {victim}'s \"\n", + " + 'reputation.\\n'\n", + " + f'{player_name} was hired by an anonymous email so 
'\n", + " + f\"they do not know who hired them to ruin {victim}'s \"\n", + " + 'reputation or what their motivation may be.\\n'\n", + " + f'{player_name} was given fake compromising material on'\n", + " f' {victim}.\\n{player_name} was offered a substantial sum of'\n", + " ' money to spread compromising materials '\n", + " + f\"to ruin {victim}'s reputation.\"\n", + " )\n", + "\n", + " return context\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qt8CK2mMbD7q" + }, + "source": [ + "## Configure and build the players" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TTrhJ3hM3bnB" + }, + "outputs": [], + "source": [ + "NUM_PLAYERS = 5\n", + "victim = 'Alice'\n", + "\n", + "def make_random_big_five()-\u003estr:\n", + " return str({\n", + " 'extraversion': random.randint(1, 10),\n", + " 'neuroticism': random.randint(1, 10),\n", + " 'openness': random.randint(1, 10),\n", + " 'conscientiousness': random.randint(1, 10),\n", + " 'agreeableness': random.randint(1, 10),\n", + " })\n", + "\n", + "player_configs = [\n", + " formative_memories.AgentConfig(\n", + " name='Alice',\n", + " gender='female',\n", + " goal='Win the election and become the mayor of Riverbend',\n", + " context=context_from_role('Alice', {'candidate', 'conservative'}),\n", + " traits = make_random_big_five()\n", + " ),\n", + " formative_memories.AgentConfig(\n", + " name='Bob',\n", + " gender='male',\n", + " goal='Win the election and become the mayor of Riverbend.',\n", + " context=context_from_role('Bob', {'candidate', 'progressive'}),\n", + " traits = make_random_big_five()\n", + " ),\n", + " formative_memories.AgentConfig(\n", + " name='Charlie',\n", + " gender='male',\n", + " goal=f\"Ruin {victim}'s reputation\",\n", + " context=context_from_role('Charlie', {'corrupt'}, victim),\n", + " traits = make_random_big_five()\n", + " ),\n", + " formative_memories.AgentConfig(\n", + " name='Dorothy',\n", + " gender='female',\n", + " goal='Have a good 
day and vote in the election.',\n", + " context=context_from_role(\n", + " 'Dorothy', {'active_voter', 'progressive'}\n", + " ),\n", + " traits = make_random_big_five()\n", + " ),\n", + " formative_memories.AgentConfig(\n", + " name='Ellen',\n", + " gender='female',\n", + " goal=(\n", + " 'Have a good day and vote in the election.'\n", + " ),\n", + " context=context_from_role('Ellen', {'active_voter', 'conservative'}),\n", + " traits = make_random_big_five()\n", + " ),\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "EemkNGWw3wXo" + }, + "outputs": [], + "source": [ + "NUM_PLAYERS = 5\n", + "\n", + "player_configs = player_configs[:NUM_PLAYERS]\n", + "player_goals = {player_config.name: player_config.goal for player_config in player_configs}\n", + "players = []\n", + "measurements = measurements_lib.Measurements()\n", + "\n", + "player_names = [player.name for player in player_configs][:NUM_PLAYERS]\n", + "with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_PLAYERS) as pool:\n", + " for agent in pool.map(build_a_citizen,\n", + " player_configs[:NUM_PLAYERS],\n", + " # All players get the same `player_names`.\n", + " [player_names] * NUM_PLAYERS,\n", + " # All players get the same `measurements` object.\n", + " [measurements] * NUM_PLAYERS):\n", + " players.append(agent)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2vt8ggYUrW8M" + }, + "source": [ + "## Build GM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "siwglxrc6z2j" + }, + "outputs": [], + "source": [ + "game_master_instructions = (\n", + " 'This is a social science experiment. It is structured as a '\n", + " 'tabletop roleplaying game (like dungeons and dragons). You are the '\n", + " 'game master. 
You will describe the current situation to the '\n", + " 'participants in the experiment and then on the basis of what you '\n", + " 'tell them they will suggest actions for the character they control. '\n", + " 'Aside from you, each other participant controls just one character. '\n", + " 'You are the game master so you may control any non-player '\n", + " 'character. You will track the state of the world and keep it '\n", + " 'consistent as time passes in the simulation and the participants '\n", + " 'take actions and change things in their world. Remember that this '\n", + " 'is a serious social science experiment. It is not just a game. It '\n", + " 'need not be fun for the participants. Always use third-person '\n", + " 'limited perspective, even when speaking directly to the participants.'\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3W65kHOKQwrv" + }, + "outputs": [], + "source": [ + "game_master_memory = associative_memory.AssociativeMemory(\n", + " embedder, importance_model_gm.importance, clock=clock.now)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bGNY_D7FID4I" + }, + "outputs": [], + "source": [ + "for player in players:\n", + " game_master_memory.add(f'{player.name} is at their private home.')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "-cxivChc633z" + }, + "outputs": [], + "source": [ + "# @title Create components and externalities\n", + "citizen_names = [player.name for player in players]\n", + "player_names = [player.name for player in players]\n", + "\n", + "instructions_construct = components.constant.ConstantConstruct(game_master_instructions, 'Instructions')\n", + "facts_on_village = components.constant.ConstantConstruct(' '.join(shared_memories), 'General knowledge of Riverbend')\n", + "player_status = gm_components.player_status.PlayerStatus(clock.now, model, game_master_memory, 
player_names)\n", + "\n", + "relevant_events = gm_components.relevant_events.RelevantEvents(clock.now, model, game_master_memory)\n", + "time_display = gm_components.time_display.TimeDisplay(clock)\n", + "\n", + "election_externality = elections.Elections(\n", + " model=model,\n", + " clock_now=clock.now,\n", + " memory=game_master_memory,\n", + " voters=players,\n", + " candidates=['Alice', 'Bob'],\n", + " verbose=True,\n", + " measurements=measurements,\n", + ")\n", + "\n", + "mem_factory = blank_memories.MemoryFactory(\n", + " model,\n", + " embedder,\n", + " importance_model_gm.importance,\n", + " clock_now=clock.now,\n", + ")\n", + "\n", + "convo_externality = gm_components.conversation.Conversation(\n", + " players,\n", + " model,\n", + " memory=game_master_memory,\n", + " clock=clock,\n", + " burner_memory_factory=mem_factory,\n", + " components=[player_status],\n", + " cap_nonplayer_characters=2,\n", + " game_master_instructions=game_master_instructions,\n", + " shared_context=shared_context,\n", + " verbose=True,\n", + ")\n", + "\n", + "\n", + "direct_effect_externality = gm_components.direct_effect.DirectEffect(\n", + " players,\n", + " model=model,\n", + " memory=game_master_memory,\n", + " clock_now=clock.now,\n", + " verbose=False,\n", + " components=[player_status]\n", + ")\n", + "\n", + "\n", + "TIME_POLLS_OPEN = datetime.datetime(hour=14, year=2024, month=10, day=1)\n", + "TIME_POLLS_CLOSE = datetime.datetime(hour=20, year=2024, month=10, day=1)\n", + "schedule = {\n", + " 'start': gm_components.schedule.EventData(\n", + " time=START_TIME,\n", + " description=None),\n", + " 'election': gm_components.schedule.EventData(\n", + " time=datetime.datetime(hour=13, year=2024, month=10, day=1),\n", + " description=(\n", + " 'The town of Riverbend is now holding an election to determine ' +\n", + " 'who will become the mayor. 
' +\n", + " f'Polls will open at {TIME_POLLS_OPEN}.')),\n", + " 'election_polls_open': gm_components.schedule.EventData(\n", + " time=TIME_POLLS_OPEN,\n", + " description=(\n", + " 'The election is happening now. Polls are open. Everyone may ' +\n", + " 'go to a polling place and cast their vote. ' +\n", + " f'Polls will close at {TIME_POLLS_CLOSE}.'),\n", + " trigger=election_externality.open_polls),\n", + " 'election_polls_close': gm_components.schedule.EventData(\n", + " time=TIME_POLLS_CLOSE,\n", + " description=(\n", + " 'The election is over. Polls are now closed. The results will ' +\n", + " 'now be tallied and a winner declared.'),\n", + " trigger=election_externality.declare_winner)\n", + "}\n", + "\n", + "schedule_construct = gm_components.schedule.Schedule(clock=clock, schedule=schedule)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "d_R2BVNOsAwa" + }, + "outputs": [], + "source": [ + "# @title Create the game master object\n", + "env = game_master.GameMaster(\n", + " model=model,\n", + " memory=game_master_memory,\n", + " clock=clock,\n", + " players=players,\n", + " components=[\n", + " instructions_construct,\n", + " facts_on_village,\n", + " player_status,\n", + " schedule_construct,\n", + " election_externality,\n", + " convo_externality,\n", + " direct_effect_externality,\n", + " relevant_events,\n", + " time_display,\n", + " ],\n", + " randomise_initiative=True,\n", + " player_observes_event=False,\n", + " verbose=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d2u0bQ1MSCGd" + }, + "source": [ + "## The RUN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "hdTRDaxEZZnN" + }, + "outputs": [], + "source": [ + "clock.set(START_TIME)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9IggLF1aH_hF" + }, + "outputs": [], + "source": [ + "for player in players:\n", + " 
player.observe(\n", + " f'{player.name} is at home, they have just woken up. Mayoral elections are going to be'\n", + " f' held today. Polls will open at {TIME_POLLS_OPEN} and close at {TIME_POLLS_CLOSE}.'\n", + " )\n", + "with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_PLAYERS) as pool:\n", + " for player in players:\n", + " pool.submit(player.update)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "2Bt87stq76gF" + }, + "outputs": [], + "source": [ + "# @title Expect about 2-3 minutes per step.\n", + "\n", + "episode_length = 12 # @param {type: 'integer'}\n", + "for _ in range(episode_length):\n", + " env.step()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DnwvpvQ4bnFs" + }, + "source": [ + "## Summary and analysis of the episode" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5U5FDXvs4HSr" + }, + "outputs": [], + "source": [ + "# @title Metrics plotting\n", + "\n", + "colab_import.reload_module(plotting)\n", + "\n", + "group_by = collections.defaultdict(lambda: 'player')\n", + "group_by['opinion_of_others'] = 'of_player'\n", + "\n", + "tb = widgets.TabBar([channel for channel in measurements.available_channels()])\n", + "for channel in measurements.available_channels():\n", + " with tb.output_to(channel):\n", + " plotting.plot_line_measurement_channel(measurements, channel,\n", + " group_by=group_by[channel],\n", + " xaxis='time_str')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "O4jp0xGXvOAJ" + }, + "outputs": [], + "source": [ + "# @title Summarize the entire story.\n", + "all_gm_memories = env._memory.retrieve_recent(k=10000, add_time=True)\n", + "\n", + "detailed_story = '\\n'.join(all_gm_memories)\n", + "print('len(detailed_story): ', len(detailed_story))\n", + "# print(detailed_story)\n", + "\n", + "episode_summary = model.sample_text(\n", + " f'Sequence of 
events:\\n{detailed_story}'+\n", + " '\\nNarratively summarize the above temporally ordered ' +\n", + " 'sequence of events. Write it as a news report. Summary:\\n',\n", + " max_characters=8000, max_tokens=8000, terminators=())\n", + "print(episode_summary)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "ALG987t-6j-V" + }, + "outputs": [], + "source": [ + "# @title Summarise the perspective of each player\n", + "player_logs = []\n", + "player_log_names = []\n", + "for player in players:\n", + " name = player.name\n", + " detailed_story = '\\n'.join(player._memory.retrieve_recent(k=1000, add_time=True))\n", + " summary = ''\n", + " summary = model.sample_text(\n", + " f'Sequence of events that happened to {name}:\\n{detailed_story}'\n", + " '\\nWrite a short story that summarises these events.\\n'\n", + " ,\n", + " max_characters=8000, max_tokens=8000, terminators=())\n", + "\n", + " all_player_mem = player._memory.retrieve_recent(k=1000, add_time=True)\n", + " all_player_mem = ['Summary:', summary, 'Memories:'] + all_player_mem\n", + " player_html = html_lib.PythonObjectToHTMLConverter(all_player_mem).convert()\n", + " player_logs.append(player_html)\n", + " player_log_names.append(f'{name}')\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UmPOvjVxddye" + }, + "source": [ + "#Build and display HTML log of the experiment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JyEoGgI05xI0" + }, + "outputs": [], + "source": [ + "history_sources = [env, direct_effect_externality, convo_externality, election_externality]\n", + "histories_html = [html_lib.PythonObjectToHTMLConverter(history.get_history()).convert() for history in history_sources]\n", + "histories_names = [history.name() for history in history_sources]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XNJuo4Dwt5Ui" + }, + "outputs": [], + "source": [ + 
"gm_mem_html = html_lib.PythonObjectToHTMLConverter(all_gm_memories).convert()\n", + "\n", + "tabbed_html = html_lib.combine_html_pages(\n", + " histories_html + [gm_mem_html] + player_logs,\n", + " histories_names + ['GM'] + player_log_names,\n", + " summary=episode_summary,\n", + " title='Riverbend elections experiment',\n", + ")\n", + "\n", + "tabbed_html = html_lib.finalise_html(tabbed_html)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pumxvmrzANOq" + }, + "outputs": [], + "source": [ + "display.HTML(tabbed_html)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HX-M9Im_dneG" + }, + "source": [ + "#Interact with a specific player" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "ESJ1l7_Kt5Uj" + }, + "outputs": [], + "source": [ + "sim_to_interact = 'Alice' # @param ['Alice', 'Bob','Charlie', 'Dorothy', 'Ellen'] {type:\"string\"}\n", + "user_identity = 'a close friend' # @param {type:\"string\"}\n", + "interaction_premise = f'{sim_to_interact} is talking to {user_identity}\\n' # @param {type:\"string\"}\n", + "\n", + "player_names = [player.name for player in players]\n", + "player_by_name = {player.name: player for player in players}\n", + "selected_player = player_by_name[sim_to_interact]\n", + "interrogation = interaction_premise" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "5Q1cYflLt5Uj" + }, + "outputs": [], + "source": [ + "utterence_from_user = 'Did you win the elections?' 
# @param {type:\"string\"}\n", + "\n", + "interrogation += f'{user_identity}: {utterence_from_user}'\n", + "player_says = selected_player.say(interrogation)\n", + "interrogation += f'\\n{sim_to_interact}: {player_says}\\n'\n", + "print(interrogation)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qkdsaKmKjASJ" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "last_runtime": { + "build_target": "", + "kind": "local" + }, + "private_outputs": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..02ad3de0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,33 @@ +[build-system] +requires = ["setuptools>=42"] +build-backend = "setuptools.build_meta" + +[tool.isort] +profile = "google" +# TODO remove once https://github.com/PyCQA/isort/pull/2149 submitted. +line_length = 1000 +single_line_exclusions = ["collections.abc", "typing"] +known_thirdparty = ["concordia"] + +[tool.pyink] +line-length = 80 +preview = true +pyink-indentation = 2 +pyink-use-majority-quotes = true + +[tool.pytest.ini_options] +required_plugins = ["pytest-xdist"] +addopts = "-n auto" +testpaths = ["concordia", "examples"] + +[tool.pytype] +python_version = "3.10" +inputs = ["concordia", "examples"] +# Keep going past errors to analyze as many files as possible. +keep_going = true +# Run N jobs in parallel. When 'auto' is used, this will be equivalent to the +# number of CPUs on the host system. +jobs = 'auto' +# Use the enum overlay for more precise enum checking. This flag is temporary +# and will be removed once this behavior is enabled by default. 
+use_enum_overlay = true diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..a9d9547e --- /dev/null +++ b/setup.py @@ -0,0 +1,83 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Install script for setuptools.""" + +import setuptools + +setuptools.setup( + name='dm-concordia', + version='1.0.0.dev.0', + license='Apache 2.0', + license_files=['LICENSE'], + url='https://github.com/google-deepmind/concordia', + download_url='https://github.com/google-deepmind/concordia/releases', + author='DeepMind', + author_email='noreply@google.com', + description=( + 'A library for building a generative model of social interacions.' + ), + keywords=( + 'multi-agent agent-based-simulation generative-agents python' + ' machine-learning' + ), + classifiers=[ + 'Development Status :: 4 - Beta', + 'Intended Audience :: Developers', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: POSIX :: Linux', + 'Operating System :: MacOS :: MacOS X', + 'Programming Language :: Python :: 3 :: Only', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + ], + package_dir={ + 'concordia': 'concordia', + }, + package_data={}, + python_requires='>=3.10', + install_requires=[ + # TODO: b/312199199 - remove some requirements. 
+ 'absl-py', + 'ipython', + 'matplotlib', + 'numpy', + 'pandas<=1.5.3', + 'python_dateutil', + 'reactivex', + 'retry', + 'rwlock', + 'saxml', + 'scipy', + 'tensorflow', + 'tensorflow_hub', + 'termcolor', + ], + extras_require={ + # Used in development. + 'dev': [ + 'build', + 'isort', + 'pipreqs', + 'pyink', + 'pylint', + 'pytest-xdist', + 'pytype', + ], + }, +)