diff --git a/.ci/FILE_HEADER b/.ci/FILE_HEADER
deleted file mode 100644
index 9ae76d4864..0000000000
--- a/.ci/FILE_HEADER
+++ /dev/null
@@ -1,2 +0,0 @@
-Copyright 2022 MosaicML Composer authors
-SPDX-License-Identifier: Apache-2.0
diff --git a/CODEOWNERS b/.github/CODEOWNERS
similarity index 91%
rename from CODEOWNERS
rename to .github/CODEOWNERS
index a183caa01f..b193288b3e 100644
--- a/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -20,8 +20,8 @@
/composer/algorithms/ @mosaicml/composer-team-eng
/composer/cli/ @mosaicml/composer-team-eng
/composer/datasets/ @mosaicml/composer-team-eng
-/composer/functional/ @mosaicml/composer-team-eng @dblalock
-/composer/loggers/ @mosaicml/composer-team-eng @eracah @dakinggg
+/composer/functional/ @mosaicml/composer-team-eng
+/composer/loggers/ @mosaicml/composer-team-eng
/composer/loss/ @mosaicml/composer-team-eng
/composer/metrics/ @mosaicml/composer-team-eng
/composer/models/ @mosaicml/composer-team-eng
diff --git a/.github/dependabot.yml b/.github/dependabot.yml
index b433af6b87..825fe27053 100644
--- a/.github/dependabot.yml
+++ b/.github/dependabot.yml
@@ -5,8 +5,8 @@
version: 2
updates:
- - package-ecosystem: "pip" # See documentation for possible values
- directory: "/" # Location of package manifests
- schedule:
- interval: "weekly"
- open-pull-requests-limit: 5
+- package-ecosystem: "pip" # See documentation for possible values
+ directory: "/" # Location of package manifests
+ schedule:
+ interval: "weekly"
+ open-pull-requests-limit: 5
diff --git a/.github/workflows/code-quality.yaml b/.github/workflows/code-quality.yaml
index 317173e094..20bbf327b7 100644
--- a/.github/workflows/code-quality.yaml
+++ b/.github/workflows/code-quality.yaml
@@ -1,42 +1,30 @@
name: Code Quality Checks
on:
- push:
- branches:
- - dev
- - main
- - release/**
- pull_request:
workflow_call:
- workflow_dispatch:
-# Cancel old runs when a new commit is pushed to the same branch if not on main or dev
-concurrency:
- group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
- cancel-in-progress: ${{ github.ref != 'refs/heads/main' && github.ref != 'refs/heads/dev' }}
+ inputs:
+ python_version:
+ required: true
+ type: string
+ pip_deps:
+ required: true
+ type: string
defaults:
run:
working-directory: .
jobs:
code-quality:
runs-on: ubuntu-20.04
- timeout-minutes: 10
- strategy:
- matrix:
- python_version:
- - "3.8"
- - "3.9"
- - "3.10"
- pip_deps:
- - "[dev]"
+ timeout-minutes: 15
steps:
- - uses: actions/checkout@v3
- - uses: actions/setup-python@v4
- with:
- python-version: ${{ matrix.python_version }}
- - name: Setup
- run: |
- set -ex
- python -m pip install --upgrade 'pip<23' wheel
- python -m pip install --upgrade .${{ matrix.pip_deps }}
- - name: Run checks
- run: |
- pre-commit run --all-files
+ - uses: actions/checkout@v3
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ inputs.python_version }}
+ - name: Setup
+ run: |
+ set -ex
+ python -m pip install --upgrade 'pip<23' wheel
+ python -m pip install --upgrade .${{ inputs.pip_deps }}
+ - name: Run checks
+ run: |
+ pre-commit run --all-files
diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml
index 1b1ccfbade..151179d524 100644
--- a/.github/workflows/codeql-analysis.yml
+++ b/.github/workflows/codeql-analysis.yml
@@ -13,12 +13,12 @@ name: "CodeQL"
on:
push:
- branches: [ dev, main ]
+ branches: [dev, main]
pull_request:
# The branches below must be a subset of the branches above
- branches: [ dev, main ]
+ branches: [dev, main]
schedule:
- - cron: '0 9 * * 1' # Every Monday at 09:00 (9:00 AM)
+ - cron: "0 9 * * 1" # Every Monday at 09:00 (9:00 AM)
jobs:
analyze:
@@ -32,9 +32,11 @@ jobs:
strategy:
fail-fast: false
matrix:
- language: [ 'python' ]
- # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ]
- # Learn more about CodeQL language support at https://git.io/codeql-language-support
+ language: ["python"]
+ # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript',
+ # 'python', 'ruby' ]
+ # Learn more about CodeQL language support at
+ # https://git.io/codeql-language-support
steps:
- name: Checkout repository
@@ -45,24 +47,28 @@ jobs:
uses: github/codeql-action/init@v2
with:
languages: ${{ matrix.language }}
- # If you wish to specify custom queries, you can do so here or in a config file.
- # By default, queries listed here will override any specified in a config file.
- # Prefix the list here with "+" to use these queries and those in the config file.
+ # If you wish to specify custom queries, you can do so here or in a
+ # config file.
+ # By default, queries listed here will override any specified in a
+ # config file.
+ # Prefix the list here with "+" to use these queries and those in the
+ # config file.
# queries: ./path/to/local/query, your-org/your-repo/queries@main
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
- # If this step fails, then you should remove it and run the build manually (see below)
+ # If this step fails, then you should remove it and run the build manually
+ # (see below)
- name: Autobuild
uses: github/codeql-action/autobuild@v2
# âšī¸ Command-line programs to run using the OS shell.
# đ https://git.io/JvXDl
- # âī¸ If the Autobuild fails above, remove it and uncomment the following three lines
- # and modify them (or add more) to build your code if your project
- # uses a compiled language
+ # âī¸ If the Autobuild fails above, remove it and uncomment the following
+ # three lines and modify them (or add more) to build your code if your
+ # project uses a compiled language
- #- run: |
+ # - run: |
# make bootstrap
# make release
diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml
index 724497b1d7..f89d67ec39 100644
--- a/.github/workflows/coverage.yaml
+++ b/.github/workflows/coverage.yaml
@@ -10,23 +10,23 @@ jobs:
timeout-minutes: 5
runs-on: ubuntu-latest
steps:
- - name: Checkout Repo
- uses: actions/checkout@v3
- - name: Setup
- run: |
- set -ex
- python -m pip install --upgrade 'pip<23' wheel
- pip install coverage[toml]==6.5.0
- - name: Download artifacts
- uses: actions/download-artifact@v3
- with:
- path: ${{ inputs.download-path }}
- - name: Generate coverage report
- run: |
- set -ex
+ - name: Checkout Repo
+ uses: actions/checkout@v3
+ - name: Setup
+ run: |
+ set -ex
+ python -m pip install --upgrade 'pip<23' wheel
+ pip install coverage[toml]==6.5.0
+ - name: Download artifacts
+ uses: actions/download-artifact@v3
+ with:
+ path: ${{ inputs.download-path }}
+ - name: Generate coverage report
+ run: |
+ set -ex
- # Flatten the coverage files
- ls ${{ inputs.download-path }} | while read x; do mv ${{ inputs.download-path }}/$x/.coverage .coverage.$x; done
+ # Flatten the coverage files
+ ls ${{ inputs.download-path }} | while read x; do mv ${{ inputs.download-path }}/$x/.coverage .coverage.$x; done
- python -m coverage combine
- python -m coverage report
+ python -m coverage combine
+ python -m coverage report
diff --git a/.github/workflows/daily.yaml b/.github/workflows/daily.yaml
index 588572c18f..3c65b0f4fa 100644
--- a/.github/workflows/daily.yaml
+++ b/.github/workflows/daily.yaml
@@ -1,12 +1,12 @@
name: Daily
on:
schedule:
- - cron: '30 2 * * *' # 2:30 every day
+ - cron: "30 2 * * *" # 2:30 every day
push:
branches:
- - dev
- - main
- - release/**
+ - dev
+ - main
+ - release/**
workflow_dispatch:
# Cancel old runs when a new commit is pushed to the same branch if not on main or dev
concurrency:
@@ -18,66 +18,56 @@ jobs:
strategy:
matrix:
include:
- - name: 'cpu-3.10-1.13'
- container: mosaicml/pytorch:1.13.1_cpu-python3.10-ubuntu20.04
- markers: 'not daily and (remote or not remote) and not gpu and not vision and not doctest'
- pytest_command: 'coverage run -m pytest'
- composer_package_name: 'mosaicml'
- - name: 'cpu-3.10-2.0'
- container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
- markers: 'not daily and (remote or not remote) and not gpu and not vision and not doctest'
- pytest_command: 'coverage run -m pytest'
- composer_package_name: 'mosaicml'
- - name: 'cpu-3.10-2.1'
- container: mosaicml/pytorch:2.1.0_cpu-python3.10-ubuntu20.04
- markers: 'not daily and (remote or not remote) and not gpu and not vision and not doctest'
- pytest_command: 'coverage run -m pytest'
- composer_package_name: 'mosaicml'
- - name: 'cpu-3.10-2.1-composer'
- container: mosaicml/pytorch:2.1.0_cpu-python3.10-ubuntu20.04
- markers: 'not daily and (remote or not remote) and not gpu and not vision and not doctest'
- pytest_command: 'coverage run -m pytest'
- composer_package_name: 'composer'
- - name: 'cpu-vision'
- container: mosaicml/pytorch_vision:1.13.1_cpu-python3.10-ubuntu20.04
- markers: 'not daily and (remote or not remote) and not gpu and vision and not doctest'
- pytest_command: 'coverage run -m pytest'
- composer_package_name: 'mosaicml'
- - name: 'cpu-doctest'
- container: mosaicml/pytorch_vision:1.13.1_cpu-python3.10-ubuntu20.04
- markers: 'not daily and (remote or not remote) and not gpu and not vision and doctest'
- pytest_command: 'coverage run -m pytest tests/test_docs.py'
- composer_package_name: 'mosaicml'
- - name: 'daily-cpu-3.10-1.13'
- container: mosaicml/pytorch:1.13.1_cpu-python3.10-ubuntu20.04
- markers: 'daily and (remote or not remote) and not gpu and not vision and not doctest'
- pytest_command: 'coverage run -m pytest'
- composer_package_name: 'mosaicml'
- - name: 'daily-cpu-3.10-2.0'
- container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
- markers: 'daily and (remote or not remote) and not gpu and not vision and not doctest'
- pytest_command: 'coverage run -m pytest'
- composer_package_name: 'mosaicml'
- - name: 'daily-cpu-3.10-2.1'
- container: mosaicml/pytorch:2.1.0_cpu-python3.10-ubuntu20.04
- markers: 'daily and (remote or not remote) and not gpu and not vision and not doctest'
- pytest_command: 'coverage run -m pytest'
- composer_package_name: 'mosaicml'
- - name: 'daily-cpu-3.10-2.1-composer'
- container: mosaicml/pytorch:2.1.0_cpu-python3.10-ubuntu20.04
- markers: 'daily and (remote or not remote) and not gpu and not vision and not doctest'
- pytest_command: 'coverage run -m pytest'
- composer_package_name: 'composer'
- - name: 'daily-cpu-vision'
- container: mosaicml/pytorch_vision:1.13.1_cpu-python3.10-ubuntu20.04
- markers: 'daily and (remote or not remote) and not gpu and vision and not doctest'
- pytest_command: 'coverage run -m pytest'
- composer_package_name: 'mosaicml'
- - name: 'daily-cpu-doctest'
- container: mosaicml/pytorch_vision:1.13.1_cpu-python3.10-ubuntu20.04
- markers: 'daily and (remote or not remote) and not gpu and not vision and doctest'
- pytest_command: 'coverage run -m pytest tests/test_docs.py'
- composer_package_name: 'mosaicml'
+ - name: cpu-3.10-2.0
+ container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
+ markers: not daily and (remote or not remote) and not gpu and not doctest
+ pytest_command: coverage run -m pytest
+ composer_package_name: mosaicml
+ - name: cpu-3.10-2.1
+ container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
+ markers: not daily and (remote or not remote) and not gpu and not doctest
+ pytest_command: coverage run -m pytest
+ composer_package_name: mosaicml
+ - name: cpu-3.10-2.1-composer
+ container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
+ markers: not daily and (remote or not remote) and not gpu and not doctest
+ pytest_command: coverage run -m pytest
+ composer_package_name: composer
+ - name: cpu-3.11-2.2
+ container: mosaicml/pytorch:2.2.0_cpu-python3.11-ubuntu20.04
+ markers: not daily and (remote or not remote) and not gpu and not doctest
+ pytest_command: coverage run -m pytest
+ composer_package_name: mosaicml
+ - name: cpu-doctest
+ container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
+ markers: not daily and (remote or not remote) and not gpu and doctest
+ pytest_command: coverage run -m pytest tests/test_docs.py
+ composer_package_name: mosaicml
+ - name: daily-cpu-3.10-2.0
+ container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
+ markers: daily and (remote or not remote) and not gpu and not doctest
+ pytest_command: coverage run -m pytest
+ composer_package_name: mosaicml
+ - name: daily-cpu-3.10-2.1
+ container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
+ markers: daily and (remote or not remote) and not gpu and not doctest
+ pytest_command: coverage run -m pytest
+ composer_package_name: mosaicml
+ - name: daily-cpu-3.10-2.1-composer
+ container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
+ markers: daily and (remote or not remote) and not gpu and not doctest
+ pytest_command: coverage run -m pytest
+ composer_package_name: composer
+ - name: daily-cpu-3.11-2.2
+ container: mosaicml/pytorch:2.2.0_cpu-python3.11-ubuntu20.04
+ markers: daily and (remote or not remote) and not gpu and not doctest
+ pytest_command: coverage run -m pytest
+ composer_package_name: mosaicml
+ - name: daily-cpu-doctest
+ container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
+ markers: daily and (remote or not remote) and not gpu and doctest
+ pytest_command: coverage run -m pytest tests/test_docs.py
+ composer_package_name: mosaicml
name: ${{ matrix.name }}
if: github.repository_owner == 'mosaicml'
with:
@@ -86,24 +76,25 @@ jobs:
pytest-command: ${{ matrix.pytest_command }}
pytest-markers: ${{ matrix.markers }}
composer_package_name: ${{ matrix.composer_package_name }}
- pytest-s3-bucket: 'mosaicml-internal-integration-testing'
- pytest-wandb-entity: 'mosaicml-public-integration-tests'
+ pytest-s3-bucket: "mosaicml-internal-integration-testing"
+ pytest-wandb-entity: "mosaicml-public-integration-tests"
pytest-wandb-project: "integration-tests-${{ github.sha }}"
secrets:
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
wandb-api-key: ${{ secrets.WANDB_API_KEY }}
- slack-notifications-bot-token: ${{ secrets.SLACK_NOTIFICATIONS_BOT_TOKEN }}
code-eval-device: ${{ secrets.CODE_EVAL_DEVICE }}
code-eval-url: ${{ secrets.CODE_EVAL_URL }}
code-eval-apikey: ${{ secrets.CODE_EVAL_APIKEY }}
gcs-key: ${{ secrets.GCS_KEY }}
gcs-secret: ${{ secrets.GCS_SECRET }}
+ azure-account-name: ${{ secrets.AZURE_ACCOUNT_NAME }}
+ azure-account-access-key: ${{ secrets.AZURE_ACCOUNT_ACCESS_KEY }}
coverage:
uses: ./.github/workflows/coverage.yaml
name: Coverage Results
if: github.repository_owner == 'mosaicml'
- needs: [ daily-pytest-cpu ]
+ needs: [daily-pytest-cpu]
with:
download-path: artifacts
@@ -114,21 +105,21 @@ jobs:
# Unlike CPU tests, we run daily tests together with GPU tests to minimize launch time
# on MCLOUD and not eat up all GPUs at once
include:
- - name: 'gpu-3.10-1.13'
- container: mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
- markers: '(daily or not daily) and (remote or not remote) and gpu and (doctest or not doctest)'
- pytest_command: 'coverage run -m pytest'
- composer_package_name: 'mosaicml'
- - name: 'gpu-3.10-2.0'
- container: mosaicml/pytorch_vision:2.0.1_cu117-python3.10-ubuntu20.04
- markers: '(daily or not daily) and (remote or not remote) and gpu and (doctest or not doctest)'
- pytest_command: 'coverage run -m pytest'
- composer_package_name: 'mosaicml'
- - name: 'gpu-3.10-2.1'
- container: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
- markers: '(daily or not daily) and (remote or not remote) and gpu and (doctest or not doctest)'
- pytest_command: 'coverage run -m pytest'
- composer_package_name: 'mosaicml'
+ - name: "gpu-3.10-2.0"
+ container: mosaicml/pytorch_vision:2.0.1_cu117-python3.10-ubuntu20.04
+ markers: "(daily or not daily) and (remote or not remote) and gpu and (doctest or not doctest)"
+ pytest_command: "coverage run -m pytest"
+ composer_package_name: "mosaicml"
+ - name: "gpu-3.10-2.1"
+ container: mosaicml/pytorch:2.1.2_cu121-python3.10-ubuntu20.04
+ markers: "(daily or not daily) and (remote or not remote) and gpu and (doctest or not doctest)"
+ pytest_command: "coverage run -m pytest"
+ composer_package_name: "mosaicml"
+ - name: "gpu-3.10-2.2"
+ container: mosaicml/pytorch:2.2.0_cu121-python3.11-ubuntu20.04
+ markers: "(daily or not daily) and (remote or not remote) and gpu and (doctest or not doctest)"
+ pytest_command: "coverage run -m pytest"
+ composer_package_name: "mosaicml"
name: ${{ matrix.name }}
if: github.repository_owner == 'mosaicml'
with:
@@ -141,4 +132,3 @@ jobs:
python-version: 3.9
secrets:
mcloud-api-key: ${{ secrets.MCLOUD_DAILY_API_KEY }}
- slack-notifications-bot-token: ${{ secrets.SLACK_NOTIFICATIONS_BOT_TOKEN }}
diff --git a/.github/workflows/docker-configure-build-push.yaml b/.github/workflows/docker-configure-build-push.yaml
index 8ae2705700..2b6bf4893d 100644
--- a/.github/workflows/docker-configure-build-push.yaml
+++ b/.github/workflows/docker-configure-build-push.yaml
@@ -38,58 +38,58 @@ jobs:
configure-build-push:
runs-on: ubuntu-latest
steps:
- - name: Maximize Build Space on Worker
- uses: easimon/maximize-build-space@v4
- with:
- overprovision-lvm: true
- remove-dotnet: true
- remove-android: true
- remove-haskell: true
+ - name: Maximize Build Space on Worker
+ uses: easimon/maximize-build-space@v4
+ with:
+ overprovision-lvm: true
+ remove-dotnet: true
+ remove-android: true
+ remove-haskell: true
- - name: Checkout
- uses: actions/checkout@v3
+ - name: Checkout
+ uses: actions/checkout@v3
- - name: Setup QEMU
- uses: docker/setup-qemu-action@v2
+ - name: Setup QEMU
+ uses: docker/setup-qemu-action@v2
- - name: Setup Docker Buildx
- uses: docker/setup-buildx-action@v2
+ - name: Setup Docker Buildx
+ uses: docker/setup-buildx-action@v2
- - name: Login to DockerHub
- uses: docker/login-action@v2
- with:
- username: ${{ secrets.username }}
- password: ${{ secrets.password }}
+ - name: Login to DockerHub
+ uses: docker/login-action@v2
+ with:
+ username: ${{ secrets.username }}
+ password: ${{ secrets.password }}
- - name: Calculate Docker Image Variables
- run: |
- set -euo pipefail
+ - name: Calculate Docker Image Variables
+ run: |
+ set -euo pipefail
- ###################
- # Calculate the tag
- ###################
- if [ "${{ inputs.staging }}" = "true" ]; then
- STAGING_REPO=${{ inputs.staging-repo }}
- IMAGE_TAG=${STAGING_REPO}:${{ inputs.image-uuid }}
- IMAGE_CACHE="${STAGING_REPO}:${{ inputs.image-name }}-buildcache"
- else
- IMAGE_TAG=${{ inputs.tags }}
- IMAGE_CACHE="${IMAGE_TAG/,*/}-buildcache"
- fi
+ ###################
+ # Calculate the tag
+ ###################
+ if [ "${{ inputs.staging }}" = "true" ]; then
+ STAGING_REPO=${{ inputs.staging-repo }}
+ IMAGE_TAG=${STAGING_REPO}:${{ inputs.image-uuid }}
+ IMAGE_CACHE="${STAGING_REPO}:${{ inputs.image-name }}-buildcache"
+ else
+ IMAGE_TAG=${{ inputs.tags }}
+ IMAGE_CACHE="${IMAGE_TAG/,*/}-buildcache"
+ fi
- echo "IMAGE_TAG=${IMAGE_TAG}" >> ${GITHUB_ENV}
- echo "IMAGE_CACHE=${IMAGE_CACHE}" >> ${GITHUB_ENV}
+ echo "IMAGE_TAG=${IMAGE_TAG}" >> ${GITHUB_ENV}
+ echo "IMAGE_CACHE=${IMAGE_CACHE}" >> ${GITHUB_ENV}
- - name: IMAGE_TAG = ${{ env.IMAGE_TAG }}
- run: echo ${{ env.IMAGE_TAG }}
+ - name: IMAGE_TAG = ${{ env.IMAGE_TAG }}
+ run: echo ${{ env.IMAGE_TAG }}
- - name: Build and Push the Docker Image
- uses: docker/build-push-action@v3
- with:
- context: ${{ inputs.context }}
- tags: ${{ env.IMAGE_TAG }}
- target: ${{ inputs.target }}
- push: ${{ inputs.push }}
- cache-from: type=registry,ref=${{ env.IMAGE_CACHE }}
- cache-to: type=registry,ref=${{ env.IMAGE_CACHE }},mode=max
- build-args: ${{ inputs.build-args }}
+ - name: Build and Push the Docker Image
+ uses: docker/build-push-action@v3
+ with:
+ context: ${{ inputs.context }}
+ tags: ${{ env.IMAGE_TAG }}
+ target: ${{ inputs.target }}
+ push: ${{ inputs.push }}
+ cache-from: type=registry,ref=${{ env.IMAGE_CACHE }}
+ cache-to: type=registry,ref=${{ env.IMAGE_CACHE }},mode=max
+ build-args: ${{ inputs.build-args }}
diff --git a/.github/workflows/pr-code-quality.yaml b/.github/workflows/pr-code-quality.yaml
new file mode 100644
index 0000000000..26d2546e75
--- /dev/null
+++ b/.github/workflows/pr-code-quality.yaml
@@ -0,0 +1,28 @@
+name: PR Code Quality Checks
+on:
+ push:
+ branches:
+ - dev
+ - main
+ - release/**
+ pull_request:
+ workflow_dispatch:
+# Cancel old runs when a new commit is pushed to the same branch if not on main
+# or dev
+concurrency:
+ group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+ cancel-in-progress: ${{ github.ref != 'refs/heads/main' && github.ref != 'refs/heads/dev' }}
+jobs:
+ code-quality:
+ uses: ./.github/workflows/code-quality.yaml
+ strategy:
+ matrix:
+ python_version:
+ - "3.9"
+ - "3.10"
+ - "3.11"
+ pip_deps:
+ - "[dev]"
+ with:
+ python_version: ${{ matrix.python_version }}
+ pip_deps: ${{ matrix.pip_deps }}
diff --git a/.github/workflows/pr-cpu.yaml b/.github/workflows/pr-cpu.yaml
index 989b4ded43..6eee54cb0b 100644
--- a/.github/workflows/pr-cpu.yaml
+++ b/.github/workflows/pr-cpu.yaml
@@ -2,7 +2,8 @@ name: PR CPU tests
on:
pull_request:
workflow_dispatch:
-# Cancel old runs when a new commit is pushed to the same branch if not on main or dev
+# Cancel old runs when a new commit is pushed to the same branch if not on main
+# or dev
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' && github.ref != 'refs/heads/dev' }}
@@ -12,31 +13,21 @@ jobs:
strategy:
matrix:
include:
- - name: 'cpu-3.10-1.13'
- container: mosaicml/pytorch:1.13.1_cpu-python3.10-ubuntu20.04
- markers: 'not daily and not remote and not gpu and not vision and not doctest'
- pytest_command: 'coverage run -m pytest'
- composer_package_name: 'mosaicml'
- - name: 'cpu-3.10-2.0'
- container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
- markers: 'not daily and not remote and not gpu and not vision and not doctest'
- pytest_command: 'coverage run -m pytest'
- composer_package_name: 'mosaicml'
- - name: 'cpu-3.10-2.1'
- container: mosaicml/pytorch:2.1.0_cpu-python3.10-ubuntu20.04
- markers: 'not daily and not remote and not gpu and not vision and not doctest'
- pytest_command: 'coverage run -m pytest'
- composer_package_name: 'mosaicml'
- - name: 'cpu-vision'
- container: mosaicml/pytorch_vision:1.13.1_cpu-python3.10-ubuntu20.04
- markers: 'not daily and not remote and not gpu and vision and not doctest'
- pytest_command: 'coverage run -m pytest'
- composer_package_name: 'mosaicml'
- - name: 'cpu-doctest'
- container: mosaicml/pytorch:1.13.1_cpu-python3.10-ubuntu20.04
- markers: 'not daily and not remote and not gpu and not vision and doctest'
- pytest_command: 'coverage run -m pytest tests/test_docs.py'
- composer_package_name: 'mosaicml'
+ - name: cpu-3.10-2.0
+ container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
+ markers: not daily and not remote and not gpu and not doctest
+ pytest_command: coverage run -m pytest
+ composer_package_name: mosaicml
+ - name: cpu-3.10-2.1
+ container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
+ markers: not daily and not remote and not gpu and not doctest
+ pytest_command: coverage run -m pytest
+ composer_package_name: mosaicml
+ - name: cpu-doctest
+ container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
+ markers: not daily and not remote and not gpu and doctest
+ pytest_command: coverage run -m pytest tests/test_docs.py
+ composer_package_name: mosaicml
name: ${{ matrix.name }}
if: github.repository_owner == 'mosaicml'
with:
@@ -49,6 +40,6 @@ jobs:
uses: ./.github/workflows/coverage.yaml
name: Coverage Results
if: github.repository_owner == 'mosaicml'
- needs: [ pytest-cpu ]
+ needs: [pytest-cpu]
with:
download-path: artifacts
diff --git a/.github/workflows/pr-docker.yaml b/.github/workflows/pr-docker.yaml
index 52e009e68f..93f0b51be1 100644
--- a/.github/workflows/pr-docker.yaml
+++ b/.github/workflows/pr-docker.yaml
@@ -2,14 +2,14 @@ name: PR Docker
on:
pull_request:
branches:
- - dev
- - main
- - release/**
+ - dev
+ - main
+ - release/**
paths:
- - .github/bin/gen_docker_matrix.py
- - .github/workflows/docker-configure-build-push.yaml
- - .github/workflows/pr-docker.yaml
- - docker/**
+ - .github/bin/gen_docker_matrix.py
+ - .github/workflows/docker-configure-build-push.yaml
+ - .github/workflows/pr-docker.yaml
+ - docker/**
workflow_dispatch:
defaults:
run:
@@ -22,21 +22,23 @@ jobs:
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- - uses: actions/setup-python@v4
- with:
- python-version: 3.9
- - uses: actions/checkout@v3
- - id: set-matrix
- run: |
- # Install yaml dependency
- pip install pyyaml
+ - uses: actions/setup-python@v4
+ with:
+ python-version: 3.9
+ - uses: actions/checkout@v3
+ - id: set-matrix
+ run: |
+ # Install yaml dependency
+ pip install pyyaml
- # Override package install command for Composer image
- COMPOSER_INSTALL_COMMAND="mosaicml[all]@git+https://github.com/mosaicml/composer.git@${{ github.sha }}"
+ # Override package install command for Composer image
+ COMPOSER_INSTALL_COMMAND="mosaicml[all]@git+https://github.com/mosaicml/composer.git@${{ github.sha }}"
- # Generate build matrix
- BUILD_MATRIX=$(python .github/bin/gen_docker_matrix.py docker/build_matrix.yaml -b COMPOSER_INSTALL_COMMAND=$COMPOSER_INSTALL_COMMAND)
- echo $BUILD_MATRIX >> $GITHUB_OUTPUT
+ # Generate build matrix
+ BUILD_MATRIX=$(python .github/bin/gen_docker_matrix.py docker/build_matrix.yaml -b \
+ COMPOSER_INSTALL_COMMAND=$COMPOSER_INSTALL_COMMAND)
+
+ echo $BUILD_MATRIX >> $GITHUB_OUTPUT
stage-docker-build:
needs: build-image-matrix
uses: ./.github/workflows/docker-configure-build-push.yaml
diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml
index 2c818b7229..1b02fc9c51 100644
--- a/.github/workflows/pr-gpu.yaml
+++ b/.github/workflows/pr-gpu.yaml
@@ -2,7 +2,8 @@ name: PR GPU tests
on:
pull_request_target:
workflow_dispatch:
-# Cancel old runs when a new commit is pushed to the same branch if not on main or dev
+# Cancel old runs when a new commit is pushed to the same branch if not on main
+# or dev
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' && github.ref != 'refs/heads/dev' }}
@@ -12,11 +13,11 @@ jobs:
strategy:
matrix:
include:
- - name: 'gpu-3.10-2.1'
- container: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
- markers: 'not daily and not remote and gpu and (doctest or not doctest)'
- pytest_command: 'coverage run -m pytest'
- composer_package_name: 'mosaicml'
+ - name: gpu-3.10-2.1
+ container: mosaicml/pytorch:2.1.2_cu121-python3.10-ubuntu20.04
+ markers: not daily and not remote and gpu and (doctest or not doctest)
+ pytest_command: coverage run -m pytest
+ composer_package_name: mosaicml
name: ${{ matrix.name }}
if: github.repository_owner == 'mosaicml'
with:
diff --git a/.github/workflows/pytest-cpu.yaml b/.github/workflows/pytest-cpu.yaml
index 152501ad64..af95f8918f 100644
--- a/.github/workflows/pytest-cpu.yaml
+++ b/.github/workflows/pytest-cpu.yaml
@@ -45,55 +45,53 @@ on:
required: false
gcs-secret:
required: false
+ azure-account-name:
+ required: false
+ azure-account-access-key:
+ required: false
jobs:
pytest-cpu:
timeout-minutes: 30
runs-on: ubuntu-latest
container: ${{ inputs.container }}
steps:
- - name: Checkout Repo
- uses: actions/checkout@v3
- - name: Setup
- run: |
- set -ex
- export PATH=/composer-python:$PATH
- export COMPOSER_PACKAGE_NAME='${{ inputs.composer_package_name }}'
- python -m pip install --upgrade 'pip<23' wheel
- python -m pip install --upgrade .[all]
- - name: Run Tests
- id: tests
- run: |
- set -ex
- export PATH=/composer-python:$PATH
- export WANDB_API_KEY='${{ secrets.wandb-api-key }}'
- export WANDB_ENTITY='${{ inputs.pytest-wandb-entity }}'
- export WANDB_PROJECT='${{ inputs.pytest-wandb-project }}'
- export AWS_ACCESS_KEY_ID='${{ secrets.aws-access-key-id }}'
- export AWS_SECRET_ACCESS_KEY='${{ secrets.aws-secret-access-key }}'
- export CODE_EVAL_DEVICE='${{ secrets.code-eval-device }}'
- export CODE_EVAL_URL='${{ secrets.code-eval-url }}'
- export CODE_EVAL_APIKEY='${{ secrets.code-eval-apikey }}'
- export GCS_KEY='${{ secrets.gcs-key }}'
- export GCS_SECRET='${{ secrets.gcs-secret }}'
- export S3_BUCKET='${{ inputs.pytest-s3-bucket }}'
- export COMMON_ARGS="-v --durations=20 -m '${{ inputs.pytest-markers }}' --s3_bucket '$S3_BUCKET' -o tmp_path_retention_policy=none"
+ - name: Checkout Repo
+ uses: actions/checkout@v3
+ - name: Setup
+ run: |
+ set -ex
+ export PATH=/composer-python:$PATH
+ export COMPOSER_PACKAGE_NAME='${{ inputs.composer_package_name }}'
+ python -m pip install --upgrade 'pip<23' wheel
+ python -m pip install --upgrade .[all]
+ - name: Run Tests
+ id: tests
+ run: |
+ set -ex
+ export PATH=/composer-python:$PATH
+ export WANDB_API_KEY='${{ secrets.wandb-api-key }}'
+ export WANDB_ENTITY='${{ inputs.pytest-wandb-entity }}'
+ export WANDB_PROJECT='${{ inputs.pytest-wandb-project }}'
+ export AWS_ACCESS_KEY_ID='${{ secrets.aws-access-key-id }}'
+ export AWS_SECRET_ACCESS_KEY='${{ secrets.aws-secret-access-key }}'
+ export CODE_EVAL_DEVICE='${{ secrets.code-eval-device }}'
+ export CODE_EVAL_URL='${{ secrets.code-eval-url }}'
+ export CODE_EVAL_APIKEY='${{ secrets.code-eval-apikey }}'
+ export GCS_KEY='${{ secrets.gcs-key }}'
+ export GCS_SECRET='${{ secrets.gcs-secret }}'
+ export AZURE_ACCOUNT_NAME='${{ secrets.azure-account-name }}'
+ export AZURE_ACCOUNT_ACCESS_KEY='${{ secrets.azure-account-access-key }}'
+ export S3_BUCKET='${{ inputs.pytest-s3-bucket }}'
+ export COMMON_ARGS="-v --durations=20 -m '${{ inputs.pytest-markers }}' --s3_bucket '$S3_BUCKET' \
+ -o tmp_path_retention_policy=none"
- # Necessary to run git diff for doctests
- git config --global --add safe.directory /__w/composer/composer
- make test PYTEST='${{ inputs.pytest-command }}' EXTRA_ARGS="$COMMON_ARGS --codeblocks"
- make test-dist PYTEST='${{ inputs.pytest-command }}' EXTRA_ARGS="$COMMON_ARGS" WORLD_SIZE=2
+ # Necessary to run git diff for doctests
+ git config --global --add safe.directory /__w/composer/composer
+ make test PYTEST='${{ inputs.pytest-command }}' EXTRA_ARGS="$COMMON_ARGS --codeblocks"
+ make test-dist PYTEST='${{ inputs.pytest-command }}' EXTRA_ARGS="$COMMON_ARGS" WORLD_SIZE=2
- python -m coverage combine
- - uses: actions/upload-artifact@v3
- with:
- name: coverage-${{ github.sha }}-${{ inputs.name }}
- path: .coverage
- - name: Notify slack fail
- if: failure() && !cancelled() && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev') && (github.event_name != 'pull_request' && github.event_name != 'pull_request_target')
- env:
- SLACK_BOT_TOKEN: ${{ secrets.slack-notifications-bot-token }}
- uses: voxmedia/github-action-slack-notify-build@v1
- with:
- channel: composer-issues
- status: FAILED
- color: danger
+ python -m coverage combine
+ - uses: actions/upload-artifact@v3
+ with:
+ name: coverage-${{ github.sha }}-${{ inputs.name }}
+ path: .coverage
diff --git a/.github/workflows/pytest-gpu.yaml b/.github/workflows/pytest-gpu.yaml
index 100de255e8..550a306746 100644
--- a/.github/workflows/pytest-gpu.yaml
+++ b/.github/workflows/pytest-gpu.yaml
@@ -38,58 +38,52 @@ on:
required: false
jobs:
pytest-gpu:
- timeout-minutes: 60 # ${{ inputs.gha-timeout }} for some reason not able to turn this into an input
+ timeout-minutes: 60 # ${{ inputs.gha-timeout }} for some reason not able to turn this into an input
runs-on: ubuntu-latest
env:
MOSAICML_API_KEY: ${{ secrets.mcloud-api-key }}
steps:
- - name: Checkout Repo
- uses: actions/checkout@v3
- - name: Setup Python
- uses: actions/setup-python@v4
- with:
- python-version: ${{ inputs.python-version }}
- - name: Cache pip
- uses: actions/cache@v3
- with:
- # This path is specific to Ubuntu
- path: ~/.cache/pip
- # Look to see if there is a cache hit for the corresponding requirements file
- key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
- restore-keys: |
- ${{ runner.os }}-pip-
- ${{ runner.os }}-
- - name: Setup MCLI
- run: |
- set -ex
- python -m pip install mosaicml-cli
- mcli version
- - name: Submit Run
- id: tests
- run: |
- set -ex
+ - name: Checkout Repo
+ uses: actions/checkout@v3
+ - name: Setup Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ inputs.python-version }}
+ - name: Cache pip
+ uses: actions/cache@v3
+ with:
+ # This path is specific to Ubuntu
+ path: ~/.cache/pip
+ # Look to see if there is a cache hit for the corresponding requirements file
+ key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
+ restore-keys: |
+ ${{ runner.os }}-pip-
+ ${{ runner.os }}-
+ - name: Setup MCLI
+ run: |
+ set -ex
+ python -m pip install mosaicml-cli
+ mcli version
+ - name: Submit Run
+ id: tests
+ run: |
+ set -ex
- PR_NUMBER="$(jq --raw-output .pull_request.number "$GITHUB_EVENT_PATH")"
- REF_ARGS=""
+ PR_NUMBER="$(jq --raw-output .pull_request.number "$GITHUB_EVENT_PATH")"
+ REF_ARGS=""
- # Use the PR number if it exists, commit SHA for protected branches and the branch name otherwise
- if [ -z "$PR_NUMBER" ] || [ "$PR_NUMBER" = "null" ]; then
- if [[ "$GITHUB_REF" =~ "refs/heads/dev" || "$GITHUB_REF" =~ "refs/heads/main" || "$GITHUB_REF" =~ "refs/heads/release" ]]; then
- REF_ARGS="--git_commit $GITHUB_SHA"
- else
- REF_ARGS="--git_branch $GITHUB_REF_NAME"
- fi
+ # Use the PR number if it exists, commit SHA for protected branches and the branch name otherwise
+ if [ -z "$PR_NUMBER" ] || [ "$PR_NUMBER" = "null" ]; then
+ if [[ "$GITHUB_REF" =~ "refs/heads/dev" || "$GITHUB_REF" =~ "refs/heads/main" || \
+ "$GITHUB_REF" =~ "refs/heads/release" ]]; then
+ REF_ARGS="--git_commit $GITHUB_SHA"
else
- REF_ARGS="--pr_number $PR_NUMBER"
+ REF_ARGS="--git_branch $GITHUB_REF_NAME"
fi
+ else
+ REF_ARGS="--pr_number $PR_NUMBER"
+ fi
- python .github/mcli/mcli_pytest.py --image '${{ inputs.container }}' --pip_package_name '${{ inputs.composer_package_name }}' --pytest_markers '${{ inputs.pytest-markers }}' --pytest_command '${{ inputs.pytest-command }}' --timeout ${{ inputs.mcloud-timeout }} ${REF_ARGS}
- - name: Notify slack fail
- if: failure() && !cancelled() && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev') && (github.event_name != 'pull_request' && github.event_name != 'pull_request_target')
- env:
- SLACK_BOT_TOKEN: ${{ secrets.slack-notifications-bot-token }}
- uses: voxmedia/github-action-slack-notify-build@v1
- with:
- channel: composer-issues
- status: FAILED
- color: danger
+ python .github/mcli/mcli_pytest.py --image '${{ inputs.container }}' --pip_package_name \
+ '${{ inputs.composer_package_name }}' --pytest_markers '${{ inputs.pytest-markers }}' --pytest_command \
+ '${{ inputs.pytest-command }}' --timeout ${{ inputs.mcloud-timeout }} ${REF_ARGS}
diff --git a/.github/workflows/release-docker.yaml b/.github/workflows/release-docker.yaml
index 17a718021e..e992663994 100644
--- a/.github/workflows/release-docker.yaml
+++ b/.github/workflows/release-docker.yaml
@@ -21,18 +21,18 @@ jobs:
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- - uses: actions/setup-python@v4
- with:
- python-version: 3.9
- - uses: actions/checkout@v3
- - id: set-matrix
- run: |
- # Install yaml dependency
- pip install pyyaml
+ - uses: actions/setup-python@v4
+ with:
+ python-version: 3.9
+ - uses: actions/checkout@v3
+ - id: set-matrix
+ run: |
+ # Install yaml dependency
+ pip install pyyaml
- # Generate build matrix
- BUILD_MATRIX=$(python .github/bin/gen_docker_matrix.py docker/build_matrix.yaml)
- echo $BUILD_MATRIX >> $GITHUB_OUTPUT
+ # Generate build matrix
+ BUILD_MATRIX=$(python .github/bin/gen_docker_matrix.py docker/build_matrix.yaml)
+ echo $BUILD_MATRIX >> $GITHUB_OUTPUT
stage-docker-build:
needs: build-image-matrix
diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index 05d78b4832..50032973ca 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -3,88 +3,99 @@ name: Release
on:
push:
tags:
- - "v*"
+ - "v*"
workflow_dispatch:
jobs:
code-quality:
uses: ./.github/workflows/code-quality.yaml
+ strategy:
+ matrix:
+ python_version:
+ - "3.9"
+ - "3.10"
+ - "3.11"
+ pip_deps:
+ - "[dev]"
+ with:
+ python_version: ${{ matrix.python_version }}
+ pip_deps: ${{ matrix.pip_deps }}
pypi-packaging:
name: Build and Publish mosaicml PyPI Package
needs:
- - code-quality
+ - code-quality
runs-on: ubuntu-latest
steps:
- - name: Checkout source
- uses: actions/checkout@v3
+ - name: Checkout source
+ uses: actions/checkout@v3
- - name: Set up Python
- uses: actions/setup-python@v3
- with:
- python-version: "3.9"
+ - name: Set up Python
+ uses: actions/setup-python@v3
+ with:
+ python-version: "3.9"
- - name: Build source and wheel distributions
- run: |
- if [[ "${{ github.ref }}" =~ refs\/tags\/v ]]; then
- PYPI_PACKAGE_NAME="mosaicml"
- else
- PYPI_PACKAGE_NAME="mosaicml-test-$(date +%Y%m%d%H%M%S)"
- fi
+ - name: Build source and wheel distributions
+ run: |
+ if [[ "${{ github.ref }}" =~ refs\/tags\/v ]]; then
+ PYPI_PACKAGE_NAME="mosaicml"
+ else
+ PYPI_PACKAGE_NAME="mosaicml-test-$(date +%Y%m%d%H%M%S)"
+ fi
- python -m pip install --upgrade build twine
- COMPOSER_PACKAGE_NAME=$PYPI_PACKAGE_NAME python -m build
- twine check --strict dist/*
+ python -m pip install --upgrade build twine
+ COMPOSER_PACKAGE_NAME=$PYPI_PACKAGE_NAME python -m build
+ twine check --strict dist/*
- - name: Publish đĻ to PyPI
- uses: pypa/gh-action-pypi-publish@release/v1
- if: contains(github.ref, 'refs/tags/v')
- with:
- user: __token__
- password: ${{ secrets.PROD_PYPI_API_TOKEN }}
+ - name: Publish đĻ to PyPI
+ uses: pypa/gh-action-pypi-publish@release/v1
+ if: contains(github.ref, 'refs/tags/v')
+ with:
+ user: __token__
+ password: ${{ secrets.PROD_PYPI_API_TOKEN }}
- - name: Publish distribution đĻ to Test PyPI
- uses: pypa/gh-action-pypi-publish@release/v1
- if: contains(github.ref, 'refs/heads/') || contains(github.ref, 'refs/pull/')
- with:
- user: __token__
- password: ${{ secrets.TEST_PYPI_API_TOKEN }}
- repository_url: https://test.pypi.org/legacy/
+ - name: Publish distribution đĻ to Test PyPI
+ uses: pypa/gh-action-pypi-publish@release/v1
+ if: contains(github.ref, 'refs/heads/') || contains(github.ref, 'refs/pull/')
+ with:
+ user: __token__
+ password: ${{ secrets.TEST_PYPI_API_TOKEN }}
+ repository_url: https://test.pypi.org/legacy/
pypi-composer-packaging:
name: Build and Publish composer PyPI Package
needs:
- - code-quality
+ - code-quality
if: contains(github.ref, 'refs/tags/v')
runs-on: ubuntu-latest
steps:
- - name: Checkout source
- uses: actions/checkout@v3
+ - name: Checkout source
+ uses: actions/checkout@v3
- - name: Set up Python
- uses: actions/setup-python@v3
- with:
- python-version: "3.9"
+ - name: Set up Python
+ uses: actions/setup-python@v3
+ with:
+ python-version: "3.9"
- - name: Build source and wheel distributions
- run: |
- PYPI_PACKAGE_NAME="composer"
+ - name: Build source and wheel distributions
+ run: |
+ PYPI_PACKAGE_NAME="composer"
- python -m pip install --upgrade build twine
- COMPOSER_PACKAGE_NAME=$PYPI_PACKAGE_NAME python -m build
- twine check --strict dist/*
+ python -m pip install --upgrade build twine
+ COMPOSER_PACKAGE_NAME=$PYPI_PACKAGE_NAME python -m build
+ twine check --strict dist/*
- - name: Publish đĻ to PyPI
- uses: pypa/gh-action-pypi-publish@release/v1
- with:
- user: __token__
- password: ${{ secrets.PROD_COMPOSER_PYPI_API_TOKEN }}
+ - name: Publish đĻ to PyPI
+ uses: pypa/gh-action-pypi-publish@release/v1
+ with:
+ user: __token__
+ password: ${{ secrets.PROD_COMPOSER_PYPI_API_TOKEN }}
production-docker-images:
name: Build and Push Production Docker Images
needs:
- - pypi-packaging
- - pypi-composer-packaging
+ - pypi-packaging
+ - pypi-composer-packaging
uses: ./.github/workflows/release-docker.yaml
if: contains(github.ref, 'refs/tags/v')
secrets:
diff --git a/.github/workflows/smoketest.yaml b/.github/workflows/smoketest.yaml
index 00121f935b..429cc40b1d 100644
--- a/.github/workflows/smoketest.yaml
+++ b/.github/workflows/smoketest.yaml
@@ -2,13 +2,14 @@ name: Smoketest
on:
push:
branches:
- - dev
- - main
- - release/**
+ - dev
+ - main
+ - release/**
pull_request:
workflow_call:
workflow_dispatch:
-# Cancel old runs when a new commit is pushed to the same branch if not on main or dev
+# Cancel old runs when a new commit is pushed to the same branch if not on main
+# or dev
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' && github.ref != 'refs/heads/dev' }}
@@ -22,20 +23,20 @@ jobs:
strategy:
matrix:
python_version:
- - "3.8"
- - "3.9"
- - "3.10"
+ - "3.9"
+ - "3.10"
+ - "3.11"
steps:
- - uses: actions/checkout@v3
- - uses: actions/setup-python@v4
- with:
- python-version: ${{ matrix.python_version }}
- - name: Setup
- run: |
- set -ex
- python -m pip install --upgrade 'pip<23' wheel
- python -m pip install --upgrade .
- python -m pip install pytest==7.2.1 pytest_codeblocks==0.16.1
- - name: Run checks
- run: |
- pytest tests/test_smoketest.py
+ - uses: actions/checkout@v3
+ - uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python_version }}
+ - name: Setup
+ run: |
+ set -ex
+ python -m pip install --upgrade 'pip<23' wheel
+ python -m pip install --upgrade .
+ python -m pip install pytest==7.2.1 pytest_codeblocks==0.16.1
+ - name: Run checks
+ run: |
+ pytest tests/test_smoketest.py
diff --git a/.gitignore b/.gitignore
index 9b66bf52db..789c75183b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -136,6 +136,9 @@ venv/
# WandB
wandb/
+# Neptune
+.neptune/
+
# Spacemacs
._#*
.#*
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index d0f8595580..4f89154571 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,125 +1,135 @@
default_language_version:
python: python3
repos:
- - repo: https://github.com/astral-sh/ruff-pre-commit
- # Ruff version.
- rev: v0.0.282
- hooks:
- - id: ruff
- args: [--fix, --exit-non-zero-on-fix]
+- repo: https://github.com/astral-sh/ruff-pre-commit
+ # Ruff version.
+ rev: v0.0.282
+ hooks:
+ - id: ruff
+ args: [--fix, --exit-non-zero-on-fix]
- - repo: https://github.com/google/yapf
- rev: v0.32.0
- hooks:
- - id: yapf
- name: yapf
- description: "A formatter for Python files."
- entry: yapf
- args: [-i, -vv, -p] #inplace
- language: python
- types: [python]
- additional_dependencies:
- - "toml"
- - repo: https://github.com/pycqa/isort
- hooks:
- - id: isort
- rev: 5.12.0
- # - repo: https://github.com/pycqa/pylint
- # hooks:
- # - id: pylint
- # entry: pylint
- # args: ['composer', 'examples', 'tests']
- # language: python
- # types: [python]
- # require_serial: true
- # rev: v2.12.2
- - repo: https://github.com/PyCQA/pydocstyle
- hooks:
- - id: pydocstyle
- name: pydocstyle
- entry: pydocstyle
- language: python
- types: [python]
- exclude: '(?:tests|.ci|composer\/algorithms|composer\/datasets|composer\/models)\/.*|composer\/trainer\/activation_checkpointing.py'
- additional_dependencies:
- - "toml"
- rev: 6.1.1
- - repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.1.0
- hooks:
- - id: check-added-large-files
- - id: check-ast
- - id: check-builtin-literals
- - id: check-case-conflict
- - id: check-docstring-first
- - id: check-executables-have-shebangs
- - id: check-json
- - id: check-shebang-scripts-are-executable
- - id: pretty-format-json
- args:
- - --autofix
- - --no-sort-keys
- - --indent=1
- - --no-ensure-ascii
- - id: check-merge-conflict
- - id: check-symlinks
- - id: check-toml
- - id: check-vcs-permalinks
- - id: check-xml
- - id: check-yaml
- - id: debug-statements
- - id: destroyed-symlinks
- - id: double-quote-string-fixer
- - id: end-of-file-fixer
- - id: fix-byte-order-marker
- - id: mixed-line-ending
- - id: trailing-whitespace
- - repo: https://github.com/Lucas-C/pre-commit-hooks
- rev: v1.1.13
- hooks:
- - id: insert-license
- args:
- - --license-filepath
- - .ci/FILE_HEADER
- - --comment-style
- - "#"
- types: [python]
- exclude: 'composer\/trainer\/activation_checkpointing.py'
+- repo: https://github.com/google/yapf
+ rev: v0.32.0
+ hooks:
+ - id: yapf
+ name: yapf
+ description: "A formatter for Python files."
+ entry: yapf
+ args: [-i, -vv, -p] # inplace
+ language: python
+ types: [python]
+ additional_dependencies:
+ - "toml"
+- repo: https://github.com/pycqa/isort
+ hooks:
+ - id: isort
+ rev: 5.12.0
+# - repo: https://github.com/pycqa/pylint
+# hooks:
+# - id: pylint
+# entry: pylint
+# args: ['composer', 'examples', 'tests']
+# language: python
+# types: [python]
+# require_serial: true
+# rev: v2.12.2
+- repo: https://github.com/PyCQA/pydocstyle
+ hooks:
+ - id: pydocstyle
+ name: pydocstyle
+ entry: pydocstyle
+ language: python
+ types: [python]
+ exclude: "(?:tests|.ci|composer\/algorithms|composer\/datasets|composer\/models)\/.*|composer\/trainer\/activation_checkpointing.py"
+ additional_dependencies:
+ - "toml"
+ rev: 6.1.1
+- repo: https://github.com/adrienverge/yamllint.git
+ rev: v1.28.0
+ hooks:
+ - id: yamllint
+ name: yamllint
+ description: This hook runs yamllint.
+ entry: yamllint
+ language: python
+ types: [file, yaml]
+- repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.1.0
+ hooks:
+ - id: check-added-large-files
+ - id: check-ast
+ - id: check-builtin-literals
+ - id: check-case-conflict
+ - id: check-docstring-first
+ - id: check-executables-have-shebangs
+ - id: check-json
+ - id: check-shebang-scripts-are-executable
+ - id: pretty-format-json
+ args:
+ - --autofix
+ - --no-sort-keys
+ - --indent=1
+ - --no-ensure-ascii
+ - id: check-merge-conflict
+ - id: check-symlinks
+ - id: check-toml
+ - id: check-vcs-permalinks
+ - id: check-xml
+ - id: check-yaml
+ - id: debug-statements
+ - id: destroyed-symlinks
+ - id: double-quote-string-fixer
+ - id: end-of-file-fixer
+ - id: fix-byte-order-marker
+ - id: mixed-line-ending
+ - id: trailing-whitespace
+- repo: https://github.com/Lucas-C/pre-commit-hooks
+ rev: v1.5.4
+ hooks:
+ - id: insert-license
+ args:
+ - --license-filepath
+ - .pre-commit/FILE_HEADER
+ - --comment-style
+ - "#"
+ - --allow-past-years
+ types: [python]
+ exclude: "composer\/trainer\/activation_checkpointing.py"
- - repo: https://github.com/kynan/nbstripout
- rev: 0.5.0
- hooks:
- - id: nbstripout
- types:
- - "jupyter"
- args:
- # Strip all the metadata that vscode or colab may add to a notebook
- - --strip-empty-cells
- - --extra-keys
- - >
- metadata.colab metadata.interpreter metadata.accelerator
- metadata.kernelspec metadata.language_info.version
- cell.metadata.heading_collapsed metadata.name metadata.nbconvert_exporter
- metadata.version metadata.vscode
- - repo: local
- hooks:
- - id: pyright
- name: pyright
- entry: pyright
- language: node
- types: [python]
- pass_filenames: false
- args: [--warnings]
- additional_dependencies: ["pyright@1.1.256"]
- - repo: https://github.com/trufflesecurity/trufflehog.git
- rev: v3.40.0
- hooks:
- - id: trufflehog
- name: secret scan
- entry: trufflehog filesystem ./
- args:
- - --only-verified
- - --fail
- - --exclude-paths=./.github/secrets/exclude.yaml
+- repo: https://github.com/kynan/nbstripout
+ rev: 0.5.0
+ hooks:
+ - id: nbstripout
+ types:
+ - "jupyter"
+ args:
+ # Strip all the metadata that vscode or colab may add to a notebook
+ - --strip-empty-cells
+ - --extra-keys
+ - >
+ metadata.colab metadata.interpreter metadata.accelerator
+ metadata.kernelspec metadata.language_info.version
+ cell.metadata.heading_collapsed metadata.name metadata.nbconvert_exporter
+ metadata.version metadata.vscode
+- repo: local
+ hooks:
+ - id: pyright
+ name: pyright
+ entry: pyright
+ language: node
+ types: [python]
+ pass_filenames: false
+ args: [--warnings]
+ additional_dependencies: ["pyright@1.1.310"]
+- repo: https://github.com/trufflesecurity/trufflehog.git
+ rev: v3.40.0
+ hooks:
+ - id: trufflehog
+ name: secret scan
+ entry: trufflehog filesystem ./
+ args:
+ - --only-verified
+ - --fail
+ - --exclude-paths=./.github/secrets/exclude.yaml
exclude: .ci\/release_tests\/.*
diff --git a/.pre-commit/FILE_HEADER b/.pre-commit/FILE_HEADER
new file mode 100644
index 0000000000..a37a1c719f
--- /dev/null
+++ b/.pre-commit/FILE_HEADER
@@ -0,0 +1,2 @@
+Copyright 2024 MosaicML Composer authors
+SPDX-License-Identifier: Apache-2.0
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index 6ebcd8535e..f8a390536d 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -7,17 +7,17 @@ version: 2
# Specify build system and tool dependencies
build:
- os: "ubuntu-20.04"
- tools:
- python: "3.8"
+ os: "ubuntu-20.04"
+ tools:
+ python: "3.10"
# Build documentation in the docs/ directory with Sphinx
sphinx:
- builder: html
- configuration: docs/source/conf.py
+ builder: html
+ configuration: docs/source/conf.py
# Optionally set the version of Python and requirements required to build your docs
python:
- install:
- - method: pip
- path: .[all]
+ install:
+ - method: pip
+ path: .[all]
diff --git a/.yamllint.yaml b/.yamllint.yaml
index 3e760e6ae0..8d8617a1cc 100644
--- a/.yamllint.yaml
+++ b/.yamllint.yaml
@@ -1,11 +1,11 @@
yaml-files:
- - "*.yaml"
- - "*.yml"
- - .yamllint
+- "*.yaml"
+- "*.yml"
+- .yamllint
ignore: |
wandb
- *
+ docker/build_matrix.yaml
rules:
braces:
@@ -14,8 +14,7 @@ rules:
forbid: false
colons: enable
commas: enable
- comments:
- level: warning
+ comments: enable
comments-indentation: enable
document-end:
present: false
@@ -26,12 +25,12 @@ rules:
hyphens: enable
indentation:
spaces: 2
- indent-sequences: true
+ indent-sequences: false
check-multi-line-strings: false
key-duplicates: enable
key-ordering: disable
line-length:
- max: 200
+ max: 120
allow-non-breakable-words: true
allow-non-breakable-inline-mappings: true
new-line-at-end-of-file: enable
diff --git a/README.md b/README.md
index 17a6e41cfd..9ab992be3a 100644
--- a/README.md
+++ b/README.md
@@ -105,7 +105,7 @@ Composer is built to automate away low-level pain points and headaches so you ca
Integrate with the tools you know and love for experiment tracking and data streaming.
- **Cloud integrations**: Our Checkpointing and logging features have first-class support for remote storage and loading from Cloud bucket (OCI, GCP, AWS S3).
-- **********Experiment tracking:********** Weights and Biases, MLFlow, and CometML â the choice is yours, easily log your data to your favorite platform.
+- **********Experiment tracking:********** Weights and Biases, MLFlow, CometML, and neptune.ai â the choice is yours, easily log your data to your favorite platform.
# **đ Getting Started**
@@ -135,26 +135,55 @@ Here is a code snippet demonstrating our Trainer on the MNIST dataset.
```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from composer import Trainer
-from composer.models import mnist_model
+from composer.models import ComposerClassifier
from composer.algorithms import LabelSmoothing, CutMix, ChannelsLast
+class Model(nn.Module):
+ """Toy convolutional neural network architecture in pytorch for MNIST."""
+
+ def __init__(self, num_classes: int = 10):
+ super().__init__()
+
+ self.num_classes = num_classes
+
+ self.conv1 = nn.Conv2d(1, 16, (3, 3), padding=0)
+ self.conv2 = nn.Conv2d(16, 32, (3, 3), padding=0)
+ self.bn = nn.BatchNorm2d(32)
+ self.fc1 = nn.Linear(32 * 16, 32)
+ self.fc2 = nn.Linear(32, num_classes)
+
+ def forward(self, x):
+ out = self.conv1(x)
+ out = F.relu(out)
+ out = self.conv2(out)
+ out = self.bn(out)
+ out = F.relu(out)
+ out = F.adaptive_avg_pool2d(out, (4, 4))
+ out = torch.flatten(out, 1, -1)
+ out = self.fc1(out)
+ out = F.relu(out)
+ return self.fc2(out)
+
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.MNIST("data", train=True, download=True, transform=transform)
train_dataloader = DataLoader(dataset, batch_size=128)
trainer = Trainer(
- model=mnist_model(num_classes=10),
+ model=ComposerClassifier(module=Model(), num_classes=10),
train_dataloader=train_dataloader,
max_duration="2ep",
algorithms=[
LabelSmoothing(smoothing=0.1),
CutMix(alpha=1.0),
ChannelsLast(),
- ]
+ ],
)
trainer.fit()
```
diff --git a/STYLE_GUIDE.md b/STYLE_GUIDE.md
index 274c10ce9c..4943a9db58 100644
--- a/STYLE_GUIDE.md
+++ b/STYLE_GUIDE.md
@@ -227,22 +227,23 @@ All imports in composer should be absolute -- that is, they do not begin with a
1. If a dependency is not core to Composer (e.g. it is for a model, dataset, algorithm, or some callbacks):
1. It must be specified in a entry of the `extra_deps` dictionary of [setup.py](setup.py).
This dictionary groups dependencies that can be conditionally installed. An entry named `foo`
- can be installed with `pip install 'mosaicml[foo]'`. For example, running `pip install 'mosaicml[unet]'`
- will install everything in `install_requires`, along with `monai` and `scikit-learn`.
+ can be installed with `pip install 'mosaicml[foo]'`. For example, running `pip install 'mosaicml[system_metrics_monitor]'`
+ will install everything in `install_requires`, along with `pynvml`.
1. It must also be specified in the `run_constrained` and the `test.requires` section.
1. The import must be conditionally imported in the code. For example:
```python
+ from composer import Callback
from composer.utils import MissingConditionalImportError
- def unet():
+ class SystemMetricsMonitor(Callback)
try:
- import monai
+ import pynvml
except ImportError as e:
- raise MissingConditionalImportError(extra_deps_group="unet",
- conda_package="monai",
+ raise MissingConditionalImportError(extra_deps_group="system_metrics_monitor",
+ conda_package="pynvml",
conda_channel="conda-forge",) from e
```
diff --git a/composer/_version.py b/composer/_version.py
index a41361e246..6a46c95e08 100644
--- a/composer/_version.py
+++ b/composer/_version.py
@@ -3,4 +3,4 @@
"""The Composer Version."""
-__version__ = '0.17.2'
+__version__ = '0.19.1'
diff --git a/composer/algorithms/alibi/attention_surgery_functions/__init__.py b/composer/algorithms/alibi/attention_surgery_functions/__init__.py
index cb27f89f2d..207f958b58 100644
--- a/composer/algorithms/alibi/attention_surgery_functions/__init__.py
+++ b/composer/algorithms/alibi/attention_surgery_functions/__init__.py
@@ -6,7 +6,8 @@
from composer.utils import MissingConditionalImportError
try:
- from composer.algorithms.alibi.attention_surgery_functions import _bert, _gpt2 # pyright: reportUnusedImport=none
+ from composer.algorithms.alibi.attention_surgery_functions import _bert # pyright: ignore[reportUnusedImport]
+ from composer.algorithms.alibi.attention_surgery_functions import _gpt2 # pyright: ignore[reportUnusedImport]
from composer.algorithms.alibi.attention_surgery_functions.utils import policy_registry
except ImportError as e:
raise MissingConditionalImportError(extra_deps_group='nlp', conda_package='transformers') from e
diff --git a/composer/algorithms/alibi/attention_surgery_functions/_bert.py b/composer/algorithms/alibi/attention_surgery_functions/_bert.py
index 915e940cad..c2a7bb3bd5 100644
--- a/composer/algorithms/alibi/attention_surgery_functions/_bert.py
+++ b/composer/algorithms/alibi/attention_surgery_functions/_bert.py
@@ -1,6 +1,7 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
+import copy
import math
from types import MethodType
from typing import Optional, Tuple
@@ -20,13 +21,14 @@ def bert_embedding_converter(module: torch.nn.Module, module_index: int, max_seq
"""
assert isinstance(module, (BertEmbeddings, RobertaEmbeddings))
del module_index # unused
- zero_and_freeze_expand_position_embeddings(module,
+ new_module = copy.deepcopy(module)
+ zero_and_freeze_expand_position_embeddings(new_module,
max_sequence_length,
position_embedding_attribute='position_embeddings')
- module_device = next(module.parameters()).device
- module.register_buffer('position_ids', torch.arange(max_sequence_length).expand((1, -1)).to(module_device))
- return module
+ module_device = next(new_module.parameters()).device
+ new_module.register_buffer('position_ids', torch.arange(max_sequence_length).expand((1, -1)).to(module_device))
+ return new_module
@policy_registry.register(BertSelfAttention, RobertaSelfAttention)
diff --git a/composer/algorithms/blurpool/README.md b/composer/algorithms/blurpool/README.md
index f99e1fb275..24b25d221a 100644
--- a/composer/algorithms/blurpool/README.md
+++ b/composer/algorithms/blurpool/README.md
@@ -56,9 +56,7 @@ def training_loop(model, train_loader):
-```python
-from composer.models import composer_deeplabv3
-
-model = composer_deeplabv3(num_classes=150,
- backbone_arch="resnet101",
- backbone_weights="IMAGENET1K_V2",
- sync_bn=False
-)
-```
-
-## Architecture
-
-Based on [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611)
-
-
-
-
-
-
-- **Backbone network**: converts the input image into a feature map.
- * Usually ResNet-101 with the strided convolutions converted to dilations convolutions in stage 3 and 4.
- * The 3x3 convolutions in stage 3 and 4 have dilation sizes of 2 and 4, respectively, to compensate for the decreased receptive field.
- * The average pooling and classification layer are ignored.
-- **Spatial Pyramid Pooling**: extracts multi-resolution features from the stage 4 backbone feature map.
- * The backbone feature map is processed with four parallel convolution layers with dilations {1, 12, 24, 36} and kernel sizes {1x1, 3x3, 3x3, 3x3}.
- * In parallel to the convolutions, global average pool the backbone feature map, then bilinearly upsample to be the same spatial dimension as the feature map.
- * Concatenate the outputs from the convolutions and global average pool, then process with a 1x1 convolution.
- * The 3x3 convolutions are implemented as depth-wise convolutions to reduce memory and computation cost.
-- **Decoder**: converts the output of spatial pyramid pooling (SPP) to class predictions of the same spatial dimension as the input image.
- * SPP output is bilinearly upsampled to be the same spatial dimension as the output from the first stage in the backbone network.
- * A 1x1 convolution is applied to the first stage activations, then this is concatenated with the upsampled SPP output.
- * The concatenation is processed by a 3x3 convolution with dropout followed by a classification layer.
- * The predictions are bilinearly upsampled to be the same resolution as the input image.
-
-## Training Hyperparameters
-
-We tested two sets of hyperparameters for DeepLabv3+ trained on the ADE20k dataset.
-
-### Typical ADE20k Model Hyperparameters
-
-- Model: deeplabv3:
- - Initializers: kaiming_normal, bn_ones
- - Number of classes: 150
- - Backbone weights: IMAGENET1K_V1
- - Sync BatchNorm
-- Optimizer: SGD
- - Learning rate: 0.01
- - Momentum: 0.9
- - Weight decay: 5.0e-4
- - Dampening: 0
- - Nsterov: false
-- LR schedulers:
- - Polynomial:
- - Alpha_f: 0.01
- - Power: 0.9
-- Number of epochs: 127
-- Batch size: 16
-- Precision: amp
-
-| Model | mIoU | Time-to-Train on 8xA100 |
-| --- | --- | --- |
-| ResNet101-DeepLabv3+ | 44.17 +/- 0.17 | 6.385 hr |
-
-### Composer ADE20k Model Hyperparameters
-
-- Model: deeplabv3:
- - Initializers: kaiming_normal, bn_ones
- - Number of classes: 150
- - Backbone Architecture: resnet101
- - Sync BatchNorm
- - Backbone weights: IMAGENET1K_V2
-- Optimizer: Decoupled SGDW
- - Learning rate: 0.01
- - Momentum: 0.9
- - Weight decay: 2.0e-5
- - Dampening: 0
- - Nesterov: false
-- LR schedulers:
- - Cosine decay, t_max: 1dur
-- Number of epochs: 128
-- Batch size: 32
-- Precision: amp
-
-| Model | mIoU | Time-to-Train on 8xA100 |
-| --- | --- | --- |
-| ResNet101-DeepLabv3+ | 45.764 +/- 0.29 | 4.67 hr |
-
-Improvements:
-
-- New PyTorch pretrained weights
-- Cosine decay
-- Decoupled Weight Decay
-- Increase batch size to 32
-- Decrease weight decay to 2e-5
-
-## Attribution
-
-[Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611) by Liang-Chieh Chen, Yukun Zhu, George Papandreou, Florian Schroff, Hartwig Adam
-
-[OpenMMLab Semantic Segmentation Toolbox and Benchmark](https://github.com/open-mmlab/mmsegmentation)
-
-[How to Train State-Of-The-Art Models Using TorchVisionâs Latest Primitives](https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/) by Vasilis Vryniotis
-
-## API Reference
-
-```{eval-rst}
-.. autoclass:: composer.models.deeplabv3.composer_deeplabv3
- :noindex:
-```
diff --git a/composer/models/deeplabv3/__init__.py b/composer/models/deeplabv3/__init__.py
deleted file mode 100644
index e3473a3015..0000000000
--- a/composer/models/deeplabv3/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""DeepLabV3 for image segmentation."""
-from composer.models.deeplabv3.model import composer_deeplabv3 as composer_deeplabv3
-
-__all__ = ['composer_deeplabv3']
diff --git a/composer/models/deeplabv3/model.py b/composer/models/deeplabv3/model.py
deleted file mode 100644
index 7e58847708..0000000000
--- a/composer/models/deeplabv3/model.py
+++ /dev/null
@@ -1,256 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""DeepLabV3 model extending :class:`.ComposerClassifier`."""
-
-import functools
-import textwrap
-import warnings
-from typing import Dict, Optional, Sequence
-
-import torch
-import torch.distributed as torch_dist
-import torch.nn.functional as F
-import torchvision
-from packaging import version
-from torchmetrics import MetricCollection
-from torchvision.models import _utils, resnet
-
-from composer.loss import DiceLoss, soft_cross_entropy
-from composer.metrics import CrossEntropy, MIoU
-from composer.models.initializers import Initializer
-from composer.models.tasks import ComposerClassifier
-from composer.utils import dist
-
-__all__ = ['deeplabv3', 'composer_deeplabv3']
-
-
-class SimpleSegmentationModel(torch.nn.Module):
-
- def __init__(self, backbone, classifier):
- warnings.warn(DeprecationWarning('SimpleSegmentationModel is deprecated and will be removed in v0.18'))
-
- super().__init__()
- self.backbone = backbone
- self.classifier = classifier
-
- def forward(self, x):
- input_shape = x.shape[-2:]
- features = self.backbone(x)
- logits = self.classifier(tuple(features.values()))
- logits = F.interpolate(logits,
- size=input_shape,
- mode='bilinear',
- align_corners=False,
- recompute_scale_factor=False)
- return logits
-
-
-def deeplabv3(num_classes: int,
- backbone_arch: str = 'resnet101',
- backbone_weights: Optional[str] = None,
- sync_bn: bool = True,
- use_plus: bool = True,
- initializers: Sequence[Initializer] = ()):
- """Helper function to build a mmsegmentation DeepLabV3 model.
-
- Args:
- num_classes (int): Number of classes in the segmentation task.
- backbone_arch (str, optional): The architecture to use for the backbone. Must be either
- [``'resnet50'``, ``'resnet101'``]. Default: ``'resnet101'``.
- backbone_weights (str, optional): If specified, the PyTorch pre-trained weights to load for the backbone.
- Currently, only ['IMAGENET1K_V1', 'IMAGENET1K_V2'] are supported. Default: ``None``.
- sync_bn (bool, optional): If ``True``, replace all BatchNorm layers with SyncBatchNorm layers.
- Default: ``True``.
- use_plus (bool, optional): If ``True``, use DeepLabv3+ head instead of DeepLabv3. Default: ``True``.
- initializers (Sequence[Initializer], optional): Initializers for the model. ``()`` for no initialization.
- Default: ``()``.
-
- Returns:
- deeplabv3: A DeepLabV3 :class:`torch.nn.Module`.
-
- Example:
-
- .. code-block:: python
-
- from composer.models.deeplabv3.deeplabv3 import deeplabv3
-
- pytorch_model = deeplabv3(num_classes=150, backbone_arch='resnet101', backbone_weights=None)
- """
- warnings.warn(DeprecationWarning('deeplabv3 is deprecated and will be removed in v0.18'))
-
- # check that the specified architecture is in the resnet module
- if not hasattr(resnet, backbone_arch):
- raise ValueError(f'backbone_arch must be part of the torchvision resnet module, got value: {backbone_arch}')
-
- # change the model weight url if specified
- if version.parse(torchvision.__version__) < version.parse('0.13.0'):
- pretrained = False
- if backbone_weights:
- pretrained = True
- if backbone_weights == 'IMAGENET1K_V1':
- resnet.model_urls[backbone_arch] = 'https://download.pytorch.org/models/resnet101-63fe2227.pth'
- elif backbone_weights == 'IMAGENET1K_V2':
- resnet.model_urls[backbone_arch] = 'https://download.pytorch.org/models/resnet101-cd907fc2.pth'
- else:
- ValueError(
- textwrap.dedent(f"""\
- `backbone_weights` must be either "IMAGENET1K_V1" or "IMAGENET1K_V2"
- if torchvision.__version__ < 0.13.0. `backbone_weights` was {backbone_weights}."""))
- backbone = getattr(resnet, backbone_arch)(pretrained=pretrained,
- replace_stride_with_dilation=[False, True, True])
- else:
- backbone = getattr(resnet, backbone_arch)(weights=backbone_weights,
- replace_stride_with_dilation=[False, True, True])
-
- # specify which layers to extract activations from
- return_layers = {'layer1': 'layer1', 'layer4': 'layer4'} if use_plus else {'layer4': 'layer4'}
- backbone = _utils.IntermediateLayerGetter(backbone, return_layers=return_layers)
-
- try:
- from mmseg.models import ASPPHead, DepthwiseSeparableASPPHead
- except ImportError as e:
- raise ImportError(
- textwrap.dedent("""\
- Either mmcv or mmsegmentation is not installed. To install mmcv, please run pip install mmcv-full==1.4.4 -f
- https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html where {cu_version} and
- {torch_version} refer to your CUDA and PyTorch versions, respectively. To install mmsegmentation, please
- run pip install mmsegmentation==0.22.0 on command-line.""")) from e
-
- world_size = dist.get_world_size()
- if sync_bn and world_size == 1:
- warnings.warn('sync_bn was true, but only one process is present for training. sync_bn will be ignored.')
-
- norm_type = 'SyncBN' if sync_bn and world_size > 1 else 'BN'
- norm_cfg = {'type': norm_type, 'requires_grad': True}
- if use_plus:
- # mmseg config:
- # https://github.com/open-mmlab/mmsegmentation/blob/master/configs/_base_/models/deeplabv3plus_r50-d8.py
- head = DepthwiseSeparableASPPHead(in_channels=2048,
- in_index=-1,
- channels=512,
- dilations=(1, 12, 24, 36),
- c1_in_channels=256,
- c1_channels=48,
- dropout_ratio=0.1,
- num_classes=num_classes,
- norm_cfg=norm_cfg,
- align_corners=False)
- else:
- # mmseg config:
- # https://github.com/open-mmlab/mmsegmentation/blob/master/configs/_base_/models/deeplabv3_r50-d8.py
- head = ASPPHead(in_channels=2048,
- in_index=-1,
- channels=512,
- dilations=(1, 12, 24, 36),
- dropout_ratio=0.1,
- num_classes=num_classes,
- norm_cfg=norm_cfg,
- align_corners=False)
-
- model = SimpleSegmentationModel(backbone, head)
-
- if initializers:
- for initializer in initializers:
- initializer_fn = Initializer(initializer).get_initializer()
-
- # Only apply initialization to classifier head if pre-trained weights are used
- if backbone_weights is None:
- model.apply(initializer_fn)
- else:
- model.classifier.apply(initializer_fn)
-
- if sync_bn and world_size > 1:
- local_world_size = dist.get_local_world_size()
-
- # List of ranks for each node, assumes that each node has the same number of ranks
- num_nodes = world_size // local_world_size
- process_group = None
- if num_nodes > 1:
- ranks_per_node = [
- list(range(node * local_world_size, (node + 1) * local_world_size)) for node in range(num_nodes)
- ]
- process_groups = [torch_dist.new_group(ranks) for ranks in ranks_per_node]
- process_group = process_groups[dist.get_node_rank()]
-
- model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=process_group)
-
- return model
-
-
-def composer_deeplabv3(num_classes: int,
- backbone_arch: str = 'resnet101',
- backbone_weights: Optional[str] = None,
- sync_bn: bool = True,
- use_plus: bool = True,
- ignore_index: int = -1,
- cross_entropy_weight: float = 1.0,
- dice_weight: float = 0.0,
- initializers: Sequence[Initializer] = ()):
- """Helper function to create a :class:`.ComposerClassifier` with a DeepLabv3(+) model. Logs
- Mean Intersection over Union (MIoU) and Cross Entropy during training and validation.
-
- From `Rethinking Atrous Convolution for Semantic Image Segmentation `_
- (Chen et al, 2017).
-
- Args:
- num_classes (int): Number of classes in the segmentation task.
- backbone_arch (str, optional): The architecture to use for the backbone. Must be either
- [``'resnet50'``, ``'resnet101'``]. Default: ``'resnet101'``.
- backbone_weights (str, optional): If specified, the PyTorch pre-trained weights to load for the backbone.
- Currently, only ['IMAGENET1K_V1', 'IMAGENET1K_V2'] are supported. Default: ``None``.
- sync_bn (bool, optional): If ``True``, replace all BatchNorm layers with SyncBatchNorm layers.
- Default: ``True``.
- use_plus (bool, optional): If ``True``, use DeepLabv3+ head instead of DeepLabv3. Default: ``True``.
- ignore_index (int): Class label to ignore when calculating the loss and other metrics. Default: ``-1``.
- cross_entropy_weight (float): Weight to scale the cross entropy loss. Default: ``1.0``.
- dice_weight (float): Weight to scale the dice loss. Default: ``0.0``.
- initializers (List[Initializer], optional): Initializers for the model. ``[]`` for no initialization.
- Default: ``[]``.
-
-
- Returns:
- ComposerModel: instance of :class:`.ComposerClassifier` with a DeepLabv3(+) model.
-
- Example:
-
- .. code-block:: python
-
- from composer.models import composer_deeplabv3
-
- model = composer_deeplabv3(num_classes=150, backbone_arch='resnet101', backbone_weights=None)
- """
- warnings.warn(DeprecationWarning('composer_deeplabv3 is deprecated and will be removed in v0.18'))
-
- model = deeplabv3(backbone_arch=backbone_arch,
- backbone_weights=backbone_weights,
- use_plus=use_plus,
- num_classes=num_classes,
- sync_bn=sync_bn,
- initializers=initializers)
-
- train_metrics = MetricCollection(
- [CrossEntropy(ignore_index=ignore_index),
- MIoU(num_classes, ignore_index=ignore_index)])
- val_metrics = MetricCollection(
- [CrossEntropy(ignore_index=ignore_index),
- MIoU(num_classes, ignore_index=ignore_index)])
-
- ce_loss_fn = functools.partial(soft_cross_entropy, ignore_index=ignore_index)
- dice_loss_fn = DiceLoss(softmax=True, batch=True, ignore_absent_classes=True)
-
- def _combo_loss(output, target) -> Dict[str, torch.Tensor]:
- loss = {'total': torch.zeros(1, device=output.device, dtype=output.dtype)}
- if cross_entropy_weight:
- loss['cross_entropy'] = ce_loss_fn(output, target)
- loss['total'] += loss['cross_entropy'] * cross_entropy_weight
- if dice_weight:
- loss['dice'] = dice_loss_fn(output, target)
- loss['total'] += loss['dice'] * dice_weight
- return loss
-
- composer_model = ComposerClassifier(module=model,
- train_metrics=train_metrics,
- val_metrics=val_metrics,
- loss_fn=_combo_loss)
- return composer_model
diff --git a/composer/models/efficientnetb0/README.md b/composer/models/efficientnetb0/README.md
deleted file mode 100644
index 9cb1096bc6..0000000000
--- a/composer/models/efficientnetb0/README.md
+++ /dev/null
@@ -1,78 +0,0 @@
-# EfficientNet
-[\[Example\]](#example) · [\[Architecture\]](#architecture) · [\[Family Members\]](#family-members) · [\[Default Training Hyperparameters\]](#default-training-hyperparameters) · [\[Attribution\]](#attribution) · [\[API Reference\]](#api-reference)
-
-`Vision` /`Image Classification`
-
-The EfficientNet model family is a set of convolutional neural networks that can be used as the basis for a variety of vision tasks, but were initially designed for image classification. The model family was designed to reach the highest accuracy for a given computation budget during inference by simultaneously scaling model depth, model width, and image resolution according to an empirically determined scaling law.
-
-## Example
-
-```python
-from composer.models import composer_efficientnetb0
-
-model = composer_efficientnetb0(num_classes=1000, drop_connect_rate=0.2)
-```
-
-## Architecture
-
-The table below from Tan and Le specifies the EfficientNet baseline architecture broken up into separate stages. MBConv indicates a mobile inverted bottleneck with a specific expansion size and kernel size. Resolution is the expected input resolution of the current stage. Number of channels is the number of output channels of the current stage. Number of layers indicates the number of repeated blocks in each stage. Subsequent EfficientNet family members scale the resolution, number of channels, and number of layers according to the resolution, width, and depth scaling parameters defined by Tan and Le.
-
-![efficientnet_arch.png](https://storage.googleapis.com/docs.mosaicml.com/images/models/efficientnet_arch.png)
-
-## Family members
-
-Tan and Le included 8 members in their model family. The goal was for each family member to have approximately double the FLOPs of the previous family member. Currently, we only support EfficientNet-B0.
-
-| Model Family Member | Parameter Count | TPU Repo Accuracy* | Our Accuracy** | Training Time on 8x3080 |
-|---------------------|-----------------|--------------------|----------------|-------------------------|
-| EfficientNet-B0 | 5.3M | 77.1% | 77.22% | 23.3 hr |
-| EfficientNet-B1 | 7.8M | 79.1% | TBA | TBA |
-| EfficientNet-B2 | 9.2M | 80.1% | TBA | TBA |
-| EfficientNet-B3 | 12M | 81.6% | TBA | TBA |
-| EfficientNet-B4 | 19M | 82.9% | TBA | TBA |
-| EfficientNet-B5 | 30M | 83.6% | TBA | TBA |
-| EfficientNet-B6 | 43M | 84.0% | TBA | TBA |
-| EfficientNet-B7 | 66M | 84.3% | TBA | TBA |
-
-*Includes label smoothing, sample-wise stochastic depth, and AutoAugment
-
-**Includes label smoothing and sample-wise stochastic depth
-
-## Default Training Hyperparameters
-
-We use the following default hyperparameters from the [Nvidia Deep Learning Examples](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/efficientnet):
-
-```yaml
-optimizer:
- rmsprop:
- lr: 0.08
- momentum: 0.9
- alpha: 0.9
- eps: 0.01
- weight_decay: 1.0e-5
-schedulers:
- - cosine_decay_with_warmup:
- t_warmup: "16ep"
-train_batch_size: 4096
-max_duration: 400ep
-```
-
-Our implementation differs from the [Nvidia Deep Learning Examples](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/efficientnet) in that we:
-
-- Apply weight decay to batch normalization trainable parameters
-- Use `momentum = 0.1` and `eps = 1e-5` as batch normalization parameters
-
-## Attribution
-
-Paper: [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946) by Mingxing Tan and Quoc V. Le
-
-Code: [gen-efficientnet-pytorch Github repository](https://github.com/rwightman/gen-efficientnet-pytorch) by Ross Wightman
-
-Hyperparameters: [DeepLearningExamples Github repository](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/efficientnet) by Nvidia
-
-## API Reference
-
-```{eval-rst}
-.. autoclass:: composer.models.efficientnetb0.composer_efficientnetb0
- :noindex:
-```
diff --git a/composer/models/efficientnetb0/__init__.py b/composer/models/efficientnetb0/__init__.py
deleted file mode 100644
index d1101f595c..0000000000
--- a/composer/models/efficientnetb0/__init__.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""The EfficientNet model family is a set of convolutional neural networks that can be used as the basis for a variety
-of vision tasks, but were initially designed for image classification. The model family was designed to reach the
-highest accuracy for a given computation budget during inference by simultaneously scaling model depth, model width, and
-image resolution according to an empirically determined scaling law.
-
-See the :doc:`Model Card ` for more details.
-"""
-from composer.models.efficientnetb0.model import composer_efficientnetb0 as composer_efficientnetb0
-
-__all__ = ['composer_efficientnetb0']
-
-_task = 'Image Classification'
-_dataset = 'ImageNet'
-_name = 'EfficientNet-B0'
-_quality = '76.63'
-_metric = 'Top-1 Accuracy'
-_ttt = '21h 48m'
-_hparams = 'efficientnetb0.yaml'
diff --git a/composer/models/efficientnetb0/_layers.py b/composer/models/efficientnetb0/_layers.py
deleted file mode 100644
index ab12aec9c3..0000000000
--- a/composer/models/efficientnetb0/_layers.py
+++ /dev/null
@@ -1,263 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-from typing import Callable, Optional
-
-import torch
-from torch import nn as nn
-
-
-def round_channels(
- channels: float,
- width_multiplier: float,
- divisor: int = 8,
- min_value: Optional[int] = None,
-) -> int:
- """Round number of channels after scaling with width multiplier.
-
- This function ensures that channel integers halfway in-between divisors is rounded up.
-
- Args:
- channels (float): Number to round.
- width_multiplier (float): Amount to scale `channels`.
- divisor (int): Number to make the output divisible by.
- min_value (int, optional): Minimum value the output can be. If not specified, defaults
- to the ``divisor``.
- """
- if not width_multiplier:
- return int(channels)
- channels *= width_multiplier
-
- min_value = min_value or divisor
- new_channels = max(min_value, int(channels + divisor / 2) // divisor * divisor)
- if new_channels < 0.9 * channels: # increase channels if rounding decreases by >10%
- new_channels += divisor
- return new_channels
-
-
-def calculate_same_padding(kernel_size, dilation, stride):
- """Calculates the amount of padding to use to get the "SAME" functionality in Tensorflow."""
- return ((stride - 1) + dilation * (kernel_size - 1)) // 2
-
-
-def drop_connect(inputs: torch.Tensor, drop_connect_rate: float, training: bool):
- """Randomly mask a set of samples. Provides similar regularization as stochastic depth.
-
- Args:
- input (torch.Tensor): Input tensor to mask.
- drop_connect_rate (float): Probability of droppping each sample.
- training (bool): Whether or not the model is training
- """
- if not training:
- return inputs
-
- keep_prob = 1 - drop_connect_rate
- rand_tensor = keep_prob + torch.rand(
- [inputs.size()[0], 1, 1, 1],
- dtype=inputs.dtype,
- device=inputs.device,
- )
- rand_tensor.floor_() # binarize
- output = inputs.div(keep_prob) * rand_tensor
- return output
-
-
-class SqueezeExcite(nn.Module):
- """Squeeze Excite Layer.
-
- Args:
- in_channels (int): Number of channels in the input tensor.
- latent_channels (int): Number of hidden channels.
- act_layer (torch.nn.Module): Activation layer to use in block.
- """
-
- def __init__(
- self,
- in_channels: int,
- latent_channels: int,
- act_layer: Callable[..., nn.Module] = nn.ReLU,
- ):
- super().__init__()
-
- self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
- self.conv_reduce = nn.Conv2d(in_channels, latent_channels, kernel_size=1, bias=True)
- self.act1 = act_layer(inplace=True)
- self.conv_expand = nn.Conv2d(latent_channels, in_channels, kernel_size=1, bias=True)
- self.gate_fn = torch.nn.Sigmoid()
-
- def forward(self, x: torch.Tensor):
- out = self.global_avg_pool(x)
- out = self.conv_reduce(out)
- out = self.act1(out)
- out = self.conv_expand(out)
- out = x * self.gate_fn(out)
- return out
-
-
-class DepthwiseSeparableConv(nn.Module):
- """Depthwise Separable Convolution layer.
-
- Args:
- in_channels (int): Number of channels in the input tensor.
- out_channels (int): Number of channels in the output tensor.
- kernel_size (int): Size of the convolving kernel.
- stride (int): Stride of the convolution.
- se_ratio (float): How much to scale `in_channels` for the hidden layer
- dimensionality of the squeeze-excite module.
- drop_connect_rate (float): Probability of dropping a sample before the
- identity connection, provides regularization similar to stochastic
- depth.
- act_layer (torch.nn.Module): Activation layer to use in block.
- norm_kwargs (dict): Normalization layer's keyword arguments.
- norm_layer (torch.nn.Module): Normalization layer to use in block.
- """
-
- def __init__(self,
- in_channels: int,
- out_channels: int,
- kernel_size: int,
- stride: int,
- se_ratio: float,
- drop_connect_rate: float,
- act_layer: Callable[..., nn.Module],
- norm_kwargs: dict,
- norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d):
- super().__init__()
- self.drop_connect_rate = drop_connect_rate
- self.has_residual = (in_channels == out_channels and stride == 1)
- self.has_se = se_ratio is not None and se_ratio > 0.0
-
- padding = calculate_same_padding(kernel_size, dilation=1, stride=stride)
- self.conv_depthwise = nn.Conv2d(in_channels=in_channels,
- out_channels=in_channels,
- groups=in_channels,
- kernel_size=kernel_size,
- stride=stride,
- padding=padding,
- bias=False)
- self.bn1 = norm_layer(in_channels, **norm_kwargs)
- self.act1 = act_layer(inplace=True)
-
- if self.has_se:
- latent_channels = max(1, int(in_channels * se_ratio))
- self.se = SqueezeExcite(in_channels, latent_channels, act_layer)
-
- self.conv_pointwise = nn.Conv2d(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=1,
- bias=False,
- )
- self.bn2 = norm_layer(out_channels, **norm_kwargs)
- self.act2 = act_layer(inplace=True)
-
- def forward(self, input: torch.Tensor):
- residual = input
-
- out = self.conv_depthwise(input)
- out = self.bn1(out)
- out = self.act1(out)
-
- if self.has_se:
- out = self.se(out)
-
- out = self.conv_pointwise(out)
- out = self.bn2(out)
- out = self.act2(out)
-
- if self.has_residual:
- if self.drop_connect_rate > 0.0:
- out = drop_connect(out, self.drop_connect_rate, self.training)
- out += residual
- return out
-
-
-class MBConvBlock(nn.Module):
- """Mobile Inverted Residual Bottleneck Block.
-
- This block is implemented as as defined in
- `MobileNetV2: Inverted Residuals and Linear Bottlenecks `_ (Sandler et al, 2018).
-
- Args:
- in_channels (int): Number of channels in the input tensor.
- out_channels (int): Number of channels in the output tensor.
- kernel_size (int): Size of the convolving kernel.
- stride (int): Stride of the convolution.
- expand_ratio (int): How much to expand the input channels for the
- depthwise convolution.
- se_ratio (float): How much to scale `in_channels` for the hidden layer
- dimensionality of the squeeze-excite module.
- drop_connect_rate (float): Probability of dropping a sample before the
- identity connection, provides regularization similar to stochastic
- depth.
- act_layer (torch.nn.Module): Activation layer to use in block.
- norm_kwargs (dict): Normalization layer's keyword arguments.
- norm_layer (torch.nn.Module): Normalization layer to use in block.
- """
-
- def __init__(self,
- in_channels: int,
- out_channels: int,
- kernel_size: int,
- stride: int,
- expand_ratio: int,
- se_ratio: float,
- drop_connect_rate: float,
- act_layer: Callable[..., nn.Module],
- norm_kwargs: dict,
- norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d):
- super().__init__()
- self.drop_connect_rate = drop_connect_rate
- self.has_residual = (in_channels == out_channels and stride == 1)
- self.has_se = se_ratio is not None and se_ratio > 0.0
-
- mid_channels = round_channels(in_channels, expand_ratio)
-
- # Point-wise convolution expansion
- self.conv1x1_expand = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
- self.bn1 = norm_layer(mid_channels, **norm_kwargs)
- self.act1 = act_layer(inplace=True)
-
- # Depth-wise Convolution
- padding = calculate_same_padding(kernel_size, dilation=1, stride=stride)
- self.conv_depthwise = nn.Conv2d(in_channels=mid_channels,
- out_channels=mid_channels,
- groups=mid_channels,
- kernel_size=kernel_size,
- stride=stride,
- padding=padding,
- bias=False)
- self.bn2 = norm_layer(mid_channels, **norm_kwargs)
- self.act2 = act_layer(inplace=True)
-
- # Squeeze and Excitation layer, if specified
- if self.has_se:
- latent_channels = max(1, int(in_channels * se_ratio))
- self.se = SqueezeExcite(mid_channels, latent_channels, act_layer)
-
- # Point-wise convolution contraction
- self.conv1x1_contract = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
- self.bn3 = norm_layer(out_channels, **norm_kwargs)
-
- def forward(self, input: torch.Tensor):
- residual = input
-
- out = self.conv1x1_expand(input)
- out = self.bn1(out)
- out = self.act1(out)
-
- out = self.conv_depthwise(out)
- out = self.bn2(out)
- out = self.act2(out)
-
- if self.has_se:
- out = self.se(out)
-
- out = self.conv1x1_contract(out)
- out = self.bn3(out)
-
- if self.has_residual:
- if self.drop_connect_rate:
- out = drop_connect(out, self.drop_connect_rate, self.training)
- out += residual
- return out
diff --git a/composer/models/efficientnetb0/efficientnets.py b/composer/models/efficientnetb0/efficientnets.py
deleted file mode 100644
index 7c544a5143..0000000000
--- a/composer/models/efficientnetb0/efficientnets.py
+++ /dev/null
@@ -1,229 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""EfficientNet model.
-
-Adapted from `(Generic) EfficientNets for PyTorch. `_.
-"""
-
-import math
-import re
-import warnings
-from typing import Callable, Optional
-
-import torch
-import torch.nn as nn
-
-from composer.models.efficientnetb0._layers import (DepthwiseSeparableConv, MBConvBlock, calculate_same_padding,
- round_channels)
-
-__all__ = ['EfficientNet']
-
-
-class EfficientNet(nn.Module):
- """EfficientNet model based on (`Tan et al, 2019 `_).
-
- Args:
- num_classes (int): Size of the EfficientNet output, typically viewed
- as the number of classes in a classification task.
- width_multiplier (float, optional): How much to scale the EfficientNet-B0 channel
- dimension throughout the model. Default: ``1.0``.
- depth_multiplier (float, optional): How much to scale the EFficientNet-B0 depth. Default: ``1.0``.
- drop_rate (float, optional): Dropout probability for the penultimate activations. Default: ``0.2``.
- drop_connect_rate (float, optional): Probability of dropping a sample before the
- identity connection, provides regularization similar to stochastic
- depth. Default: ``0.2``.
- act_layer (torch.nn.Module, optional): Activation layer to use in the model. Default: ``nn.SiLU``.
- norm_kwargs (dict, optional): Normalization layer's keyword arguments. Default: ``{"momentum": 0.1, "eps": 1e-5}``.
- norm_layer (torch.nn.Module, optional): Normalization layer to use in the model. Default: ``nn.BatchNorm2d``.
- """
-
- # EfficientNet-B0 architecture specification.
- # block_strings are decoded into block level hyperparameters.
- # r=repeat, k=kernel_size, s=stride, e=expand_ratio, i=in_channels, o=out_channels, se=se_ratio.
- _blocks_strings = [
- 'r1_k3_s1_e1_i32_o16_se0.25',
- 'r2_k3_s2_e6_i16_o24_se0.25',
- 'r2_k5_s2_e6_i24_o40_se0.25',
- 'r3_k3_s2_e6_i40_o80_se0.25',
- 'r3_k5_s1_e6_i80_o112_se0.25',
- 'r4_k5_s2_e6_i112_o192_se0.25',
- 'r1_k3_s1_e6_i192_o320_se0.25',
- ]
-
- def __init__(self,
- num_classes: int,
- width_multiplier: float = 1.0,
- depth_multiplier: float = 1.0,
- drop_rate: float = 0.2,
- drop_connect_rate: float = 0.2,
- act_layer: Callable[..., nn.Module] = nn.SiLU,
- norm_kwargs: Optional[dict] = None,
- norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d):
- warnings.warn(DeprecationWarning('EfficientNet is deprecated and will be removed in v0.18'))
-
- super(EfficientNet, self).__init__()
- self.num_classes = num_classes
-
- if norm_kwargs is None:
- norm_kwargs = {'momentum': 0.1, 'eps': 1e-5}
-
- in_channels = 3
- out_channels = round_channels(32, width_multiplier)
- padding = calculate_same_padding(kernel_size=3, dilation=1, stride=2)
- self.conv_stem = nn.Conv2d(
- in_channels,
- out_channels,
- kernel_size=3,
- stride=2,
- padding=padding,
- bias=False,
- )
- self.bn1 = norm_layer(num_features=out_channels, **norm_kwargs)
- self.act1 = act_layer(inplace=True)
-
- # Count the number of blocks in the model
- block_count = 0.
- for block_string in self._blocks_strings:
- _, num_repeat = self._decode_block_string(block_string)
- block_count += num_repeat
-
- # Decode block strings and add blocks
- block_idx = 0.
- blocks = []
- block_args = {}
- for block_string in self._blocks_strings:
- block_args, num_repeat = self._decode_block_string(block_string)
- # Scale channels and number of repeated blocks based on multipliers
- block_args['in_channels'] = round_channels(
- block_args['in_channels'],
- width_multiplier,
- )
- block_args['out_channels'] = round_channels(
- block_args['out_channels'],
- width_multiplier,
- )
- num_repeat = int(math.ceil(depth_multiplier * num_repeat))
-
- # Add activation, normalization layers, and drop connect
- block_args['act_layer'] = act_layer
- block_args['norm_kwargs'] = norm_kwargs
- block_args['norm_layer'] = norm_layer
-
- # Delete expand_ratio when set to 1 to use depthwise separable convolution layer
- if block_args['expand_ratio'] == 1:
- del block_args['expand_ratio']
-
- for i in range(num_repeat):
- # Linearly decay drop_connect_rate across model depth
- block_args['drop_connect_rate'] = drop_connect_rate * block_idx / block_count
-
- if 'expand_ratio' not in block_args:
- blocks.append(DepthwiseSeparableConv(**block_args))
- else:
- blocks.append(MBConvBlock(**block_args))
- block_idx += 1
-
- # Only the first block in a stage can have stride != 1
- if i == 0:
- block_args['stride'] = 1
- block_args['in_channels'] = block_args['out_channels']
-
- self.blocks = nn.Sequential(*blocks)
-
- in_channels = block_args['out_channels']
- out_channels = round_channels(1280, width_multiplier)
- self.conv_head = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
- self.bn2 = norm_layer(out_channels, **norm_kwargs)
- self.act2 = act_layer(inplace=True)
-
- self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
- self.dropout = nn.Dropout(drop_rate)
- self.classifier = nn.Linear(out_channels, num_classes)
-
- # Initialization from gen-efficientnet-pytorch repo
- for m in self.modules():
- if isinstance(m, torch.nn.Conv2d):
- fan_out = (m.kernel_size[0] * m.kernel_size[1] * m.out_channels) // m.groups
- m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, torch.nn.BatchNorm2d):
- m.weight.data.fill_(1.0)
- m.bias.data.zero_()
- elif isinstance(m, torch.nn.Linear):
- fan_out = m.weight.size(0)
- init_range = 1.0 / math.sqrt(fan_out)
- m.weight.data.uniform_(-init_range, init_range)
- m.bias.data.zero_()
-
- def extract_features(self, input: torch.Tensor):
- out = self.conv_stem(input)
- out = self.bn1(out)
- out = self.act1(out)
- out = self.blocks(out)
- out = self.conv_head(out)
- out = self.bn2(out)
- out = self.act2(out)
- out = self.global_avg_pool(out)
- return out.flatten(1)
-
- def forward(self, input: torch.Tensor):
- out = self.extract_features(input)
- out = self.dropout(out)
- return self.classifier(out)
-
- @staticmethod
- def get_model_from_name(model_name: str, num_classes, drop_connect_rate: float):
- """Instantiate an EfficientNet model family member based on the model_name string.
-
- Args:
- model_name: (str): One of ``'efficientnet-b0'`` through ``'efficientnet-b7'``.
- num_classes (int): Size of the EfficientNet output, typically viewed as the number of classes in a classification task.
- drop_connect_rate (float): Probability of dropping a sample before the identity connection,
- provides regularization similar to stochastic depth.
- """
-
- # Coefficients: width, depth, res, dropout
- model_arch = {
- 'efficientnet-b0': (1.0, 1.0, 224, 0.2),
- 'efficientnet-b1': (1.0, 1.1, 240, 0.2),
- 'efficientnet-b2': (1.1, 1.2, 260, 0.3),
- 'efficientnet-b3': (1.2, 1.4, 300, 0.3),
- 'efficientnet-b4': (1.4, 1.8, 380, 0.4),
- 'efficientnet-b5': (1.6, 2.2, 456, 0.4),
- 'efficientnet-b6': (1.8, 2.6, 528, 0.5),
- 'efficientnet-b7': (2.0, 3.1, 600, 0.5),
- }
-
- model_params = model_arch[model_name]
- width_multiplier = model_params[0]
- depth_multiplier = model_params[1]
- drop_rate = model_params[3]
- return EfficientNet(num_classes=num_classes,
- width_multiplier=width_multiplier,
- depth_multiplier=depth_multiplier,
- drop_rate=drop_rate,
- drop_connect_rate=drop_connect_rate)
-
- def _decode_block_string(self, block_string: str):
- """Decodes an EfficientNet block specification string into a dictionary of keyword arguments for a block in the
- architecture."""
-
- arg_strings = block_string.split('_')
- args = {}
- for arg_string in arg_strings:
- splits = re.split(r'(\d.*)', arg_string)
- if len(splits) >= 2:
- key, value = splits[:2]
- args[key] = value
- num_repeat = int(args['r'])
- block_args = {
- 'kernel_size': int(args['k']),
- 'stride': int(args['s']),
- 'expand_ratio': int(args['e']),
- 'in_channels': int(args['i']),
- 'out_channels': int(args['o']),
- 'se_ratio': float(args['se']) if 'se' in args else None,
- }
- return block_args, num_repeat
diff --git a/composer/models/efficientnetb0/model.py b/composer/models/efficientnetb0/model.py
deleted file mode 100644
index 67ae193895..0000000000
--- a/composer/models/efficientnetb0/model.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""A :class:`.ComposerClassifier` wrapper around the EfficientNet-b0 architecture."""
-
-import warnings
-
-from composer.models.efficientnetb0.efficientnets import EfficientNet
-from composer.models.tasks import ComposerClassifier
-
-__all__ = ['composer_efficientnetb0']
-
-
-def composer_efficientnetb0(num_classes: int = 1000, drop_connect_rate: float = 0.2) -> ComposerClassifier:
- """Helper function to create a :class:`.ComposerClassifier` with an EfficientNet-b0 architecture.
-
- See `Rethinking Model Scaling for Convolutional Neural Networks `_
- (Tan et al, 2019) for more details.
-
- Args:
- num_classes (int, optional): The number of classes. Needed for classification tasks. Default: ``1000``.
- drop_connect_rate (float, optional): Probability of dropping a sample within a block before identity
- connection. Default: ``0.2``.
-
- Returns:
- ComposerModel: instance of :class:`.ComposerClassifier` with a EfficientNet-B0 model.
-
-
- Example:
-
- .. testcode::
-
- from composer.models import composer_efficientnetb0
-
- model = composer_efficientnetb0() # creates EfficientNet-b0 for image classification
- """
- warnings.warn(DeprecationWarning('composer_efficientnetb0 is deprecated and will be removed in v0.18'))
- model = EfficientNet.get_model_from_name(model_name='efficientnet-b0',
- num_classes=num_classes,
- drop_connect_rate=drop_connect_rate)
-
- composer_model = ComposerClassifier(module=model)
- return composer_model
diff --git a/composer/models/gpt2/README.md b/composer/models/gpt2/README.md
deleted file mode 100644
index 52ee26a97f..0000000000
--- a/composer/models/gpt2/README.md
+++ /dev/null
@@ -1,81 +0,0 @@
-# GPT-2
-[\[Example\]](#example) · [\[Architecture\]](#architecture) · [\[Family Members\]](#family-members) · [\[Default Training Hyperparameters\]](#default-training-hyperparameters) · [\[Attribution\]](#attribution) · [\[API Reference\]](#api-reference)
-
-`NLP` / ``Autoregressive Language Modeling``
-
-The GPT-2 model family is set of transformer-based networks for autoregressive language modeling at various scales. This family was originally proposed by OpenAI, and is trained on the OpenWebText dataset. It is useful for downstream language generation tasks, such as summarization, translation, and dialog.
-
-Our codebase builds off of the Hugging Face *[Transformers](https://huggingface.co/transformers/)* library. We initialize Huggingface's GPT-2 model with one of our configurations.
-
-## Example
-
-
-
-```python
-import transformers
-from composer.models import GPT2Model
-
-model = GPT2Model(module=transformers.AutoModelForCausalLM.from_pretrained("gpt2"),
- config=transformers.GPT2Config.from_pretrained("gpt2"),
- tokenizer_name="gpt2")
-```
-
-## Architecture
-
-GPT-2 consists of a a decoder-only Transformer parameterized by $n_{layer}$, $d_{model}$, $d_{ff}$, $d_{attn}$ and $n_{heads}$. The parameters for each model family member can be seen below:
-
-| Name | $n_{layer}$ | $d_{model}$ | $d_{ff}$ | $d_{attn}$ | $n_{heads}$ |
-|------------|-------------|-------------|----------|------------|-------------|
-| GPT-2 52M | 8 | 512 | 2048 | 8 | 8 |
-| GPT-2 83M | 10 | 640 | 2560 | 640 | 10 |
-| GPT-2 125M | 12 | 768 | 3072 | 768 | 12 |
-
-## Family Members
-
-We implement three members of this family at different scales: GPT 52M, GPT 83M, and GPT 125M. These models are named after their parameter counts. We selected these particular configurations because (1) they represent points on the pareto frontier of the scaling law for language models as described by [Kaplan et al. at OpenAI](https://arxiv.org/abs/2001.08361) and (2) they are small enough to rapidly iterate on methods using a single GPU node.
-
-| Model Family Member | Parameters | Training Hours on 8xA100s | Training Tokens | Final Loss | Predicted Perplexity | Actual Perplexity |
-|---------------------|------------|---------------------------|-----------------|------------|----------------------|-------------------|
-| GPT-2 52M | 53.9M | 02:44 | 4.6B | 3.43 | 32.54 | 30.88 |
-| GPT-2 83M | 85.8M | 04:52 | 5.5B | 3.28 | 27.84 | 26.57 |
-| GPT-2 125M | 114M | 08:25 | 6.7B | 3.18 | 24.64 | 24.04 |
-
-
-There are two ways of varying the amount of time necessary to train a model or the cost necessary to do so: varying the size of the model or varying the number of steps (and therefore data) for which the model is trained. With the GPT family of models, we explore both of these axes. To develop methods for these models, we generally begin with the smallest members of this model family for initial experimentation and scale up once the ideas have been refined.
-
-To explore tradeoffs between quality and number of training steps: we have ablated both number of training steps, and number of data points to train on. We do this by checkpointing the model throughout training.
-
-To explore tradeoffs between quality and the size of the model, we use [Scaling Laws for Neural Language Models](https://arxiv.org/abs/2001.08361) to provide suggestions on model capacity and dataset size, and then sweep hyperparameters such as learning rate and batch size to minimize loss.
-
-
-## Attribution
-
-The GPT model family is described in *[Language Models are Unsupervised Multitask Learners](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)* by Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, and Ilya Sutskever.
-
-The Scaling Law that we use to choose the members of this model family are described in *[Scaling Laws for Neural Language Models](https://arxiv.org/abs/2001.08361)* by Jared Kaplan, Sam McCandish, Tom Henighan, Tom B. Brown, Benjamin Chess, Rewon Child, Scott Gray, Alec Radford, Jeffrey Wu, and Dario Amodei.
-
-## Default Training Hyperparameters
-
-Below are hyperparameters we used to train GPT-2 125M on [OpenWebText](https://huggingface.co/datasets/openwebtext).
-
-```yaml
-optimizer:
- adamw:
- lr: 6.0e-4
- betas:
- - 0.9
- - 0.999
- eps: 1.0e-08
- weight_decay: 0.0
-schedulers:
- - cosine_decay_with_warmup:
- t_warmup: 140ba
-train_batch_size: 512
-```
-
-## API Reference
-
-```{eval-rst}
-.. autoclass:: composer.models.gpt2.GPT2Model
- :noindex:
-```
diff --git a/composer/models/gpt2/__init__.py b/composer/models/gpt2/__init__.py
deleted file mode 100644
index 1ae37b122a..0000000000
--- a/composer/models/gpt2/__init__.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""The GPT-2 model family is set of transformer-based networks for autoregressive language modeling at various scales.
-This family was originally proposed by OpenAI, and is trained on the OpenWebText dataset. It is useful for downstream
-language generation tasks, such as summarization, translation, and dialog.
-
-See the :doc:`Model Card ` for more details.
-"""
-
-from composer.models.gpt2.model import create_gpt2 as create_gpt2
-
-__all__ = ['create_gpt2']
-
-_metadata = {
- 'gpt2': {
- '_task': 'Language Modeling',
- '_dataset': 'OpenWebText',
- '_name': 'GPT-2 52M',
- '_quality': '30.88',
- '_metric': 'Perplexity',
- '_ttt': '02:44',
- '_hparams': 'gpt2_52m.yaml'
- },
- 'gpt2 -- TODO RENAME TO GPT2': {
- '_task': 'Language Modeling',
- '_dataset': 'OpenWebText',
- '_name': 'GPT-2 83M',
- '_quality': '26.57',
- '_metric': 'Perplexity',
- '_ttt': '04:52',
- '_hparams': 'gpt2_83m.yaml'
- },
- 'gpt2 --! TODO RENAME TO GPT2': {
- '_task': 'Language Modeling',
- '_dataset': 'OpenWebText',
- '_name': 'GPT-2 125M',
- '_quality': '24.04',
- '_metric': 'Perplexity',
- '_ttt': '08:25',
- '_hparams': 'gpt2_125m.yaml'
- }
-}
diff --git a/composer/models/gpt2/model.py b/composer/models/gpt2/model.py
deleted file mode 100644
index ea924b7b99..0000000000
--- a/composer/models/gpt2/model.py
+++ /dev/null
@@ -1,121 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""GPT-2 model based on `Hugging Face GPT-2 `_.
-
-Implemented as a wrapper using :class:`.ComposerTrainer`.
-"""
-
-from __future__ import annotations
-
-import warnings
-from typing import Optional
-
-from composer.metrics.nlp import LanguageCrossEntropy, LanguagePerplexity
-from composer.models.huggingface import HuggingFaceModel
-from composer.utils.import_helpers import MissingConditionalImportError
-
-__all__ = ['create_gpt2']
-
-
-def create_gpt2(use_pretrained: Optional[bool] = False,
- pretrained_model_name: Optional[str] = None,
- model_config: Optional[dict] = None,
- tokenizer_name: Optional[str] = None,
- gradient_checkpointing: Optional[bool] = False):
- """Implements :class:`~composer.models.huggingface.HuggingFaceModel` to wrap `Hugging Face GPT-2 \
- transformers `_. Logs training and
- validation perplexity.
-
- From `Language Models are Unsupervised Multitask Learners `_ (Radford et al, 2018).
-
- Args:
-
- gradient_checkpointing (bool, optional): Use gradient checkpointing. Default: ``False``.
- use_pretrained (bool, optional): Whether to initialize the model with the pretrained weights. Default: ``False``.
- model_config (dict): A dictionary providing a HuggingFace model configuration.
- tokenizer_name (str, optional): Tokenizer name used to preprocess the dataset
- and validate the models inputs.
-
- .. code-block::
-
- {
- "_name_or_path": "gpt2",
- "activation_function": "gelu_new",
- "architectures": ["GPT2LMHeadModel"],
- "attn_pdrop": 0.1,
- "bos_token_id": 50256,
- "embd_pdrop": 0.1,
- "eos_token_id": 50256,
- "initializer_range": 0.02,
- "layer_norm_epsilon": 1e-05,
- "model_type": "gpt2",
- "n_ctx": 1024,
- "n_embd": 768,
- "n_head": 12,
- "n_inner": null,
- "n_layer": 12,
- "n_positions": 1024,
- "reorder_and_upcast_attn": false,
- "resid_pdrop": 0.1,
- "scale_attn_by_inverse_layer_idx": false,
- "scale_attn_weights": true,
- "summary_activation": null,
- "summary_first_dropout": 0.1,
- "summary_proj_to_labels": true,
- "summary_type": "cls_index",
- "summary_use_proj": true,
- "task_specific_params": {
- "text-generation": {
- "do_sample": true,
- "max_length": 50 }
- },
- "transformers_version": "4.16.0",
- "use_cache": true,
- "vocab_size": 50257
- }
-
- To create a GPT-2 model for language modeling pretraining:
-
- .. testcode::
-
- from composer.models import create_gpt2
-
- composer_model = create_gpt2()
-
- """
- warnings.warn(DeprecationWarning('create_gpt2 is deprecated and will be removed in v0.18'))
-
- try:
- import transformers
- except ImportError as e:
- raise MissingConditionalImportError(extra_deps_group='nlp', conda_package='transformers') from e
-
- if not model_config:
- model_config = {}
-
- if not pretrained_model_name:
- pretrained_model_name = 'gpt2'
-
- if use_pretrained:
- assert transformers.AutoModelForCausalLM.from_pretrained is not None, 'AutoModelForCausalLM has from_pretrained method'
- model = transformers.AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=pretrained_model_name,
- **model_config)
- else:
- config = transformers.AutoConfig.from_pretrained(pretrained_model_name, **model_config)
- assert transformers.AutoModelForCausalLM.from_config is not None, 'AutoModelForCausalLM has from_config method'
- model = transformers.AutoModelForCausalLM.from_config(config)
-
- if gradient_checkpointing:
- model.gradient_checkpointing_enable() # type: ignore
-
- # setup the tokenizer
- if tokenizer_name:
- tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
- else:
- tokenizer = None
-
- return HuggingFaceModel(model=model,
- tokenizer=tokenizer,
- metrics=[LanguageCrossEntropy(), LanguagePerplexity()],
- use_logits=True)
diff --git a/composer/models/huggingface.py b/composer/models/huggingface.py
index 9cf5939bfe..439f8b50fe 100644
--- a/composer/models/huggingface.py
+++ b/composer/models/huggingface.py
@@ -5,6 +5,7 @@
from __future__ import annotations
+import copy
import inspect
import json
import logging
@@ -13,24 +14,31 @@
import string
import tempfile
import textwrap
+import warnings
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Type, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, Tuple, Type, Union
import torch
from torchmetrics import Metric
-from composer.metrics import InContextLearningMetric, InContextLearningQAAccuracy
from composer.models.base import ComposerModel
from composer.utils import MissingConditionalImportError, dist, get_file, import_object, is_model_fsdp, safe_torch_load
+try:
+ from peft import PeftModel, get_peft_model
+ peft_installed = True
+except:
+ peft_installed = False
+
if TYPE_CHECKING:
import transformers
+ from peft import PeftConfig, PeftModel
from transformers import PretrainedConfig
from transformers.models.auto.auto_factory import _BaseAutoModelClass
log = logging.getLogger(__name__)
-__all__ = ['HuggingFaceModel']
+__all__ = ['HuggingFaceModel', 'peft_installed']
class HuggingFaceModel(ComposerModel):
@@ -38,7 +46,7 @@ class HuggingFaceModel(ComposerModel):
A wrapper class that converts đ¤ Transformers models to composer models.
Args:
- model (transformers.PreTrainedModel): A đ¤ Transformers model.
+ model (Union[transformers.PreTrainedModel, peft.PeftModel)): A đ¤ Transformers model or a PEFT model.
tokenizer (transformers.PreTrainedTokenizer, optional): The tokenizer used to prepare the dataset. Default ``None``.
.. note:: If the tokenizer is provided, its config will be saved in the composer checkpoint, and it can be reloaded
@@ -48,6 +56,8 @@ class HuggingFaceModel(ComposerModel):
eval_metrics (list[Metric], optional): list of torchmetrics to compute on the eval_dataloader, or be accessible to :class:`Evaluator`s. Default: ``None``.
shift_labels (bool, optional): If True, the batch's labels will be shifted before being used to calculate metrics. This should be set to true for CausalLM models and false otherwise. If not specified, `shift_labels` will be set automatically based on the model class name. Default: ``None``.
allow_embedding_resizing (bool, optional): If True, the model's embeddings will be automatically resized when they are smaller than the tokenizer vocab size. Default: ``False``.
+ peft_config (PeftConfig, optional): Optional PEFT config to apply to the model. If provided, the model will be converted to a PEFT model. Only LoRA is currently supported.
+ should_save_peft_only (bool, optional): If True _and_ PEFT is active, the state dict will only contain the PEFT weights, not the frozen base model weights.
.. note:: To ensure correct behavior, set `shift_labels` manually if using a custom model (i.e., if `model` is not
an instance of a registered đ¤ Transformers class).
@@ -66,14 +76,16 @@ class HuggingFaceModel(ComposerModel):
"""
def __init__(self,
- model: transformers.PreTrainedModel,
+ model: Union[transformers.PreTrainedModel, 'PeftModel'],
tokenizer: Optional[Union[transformers.PreTrainedTokenizer,
transformers.PreTrainedTokenizerFast]] = None,
use_logits: Optional[bool] = False,
metrics: Optional[List[Metric]] = None,
eval_metrics: Optional[List[Metric]] = None,
shift_labels: Optional[bool] = None,
- allow_embedding_resizing: bool = False) -> None:
+ allow_embedding_resizing: bool = False,
+ peft_config: Optional['PeftConfig'] = None,
+ should_save_peft_only: bool = True) -> None:
try:
import transformers
del transformers # unused
@@ -82,71 +94,118 @@ def __init__(self,
conda_package='transformers',
conda_channel='conda-forge') from e
+ if peft_config is not None:
+ if not peft_installed:
+ raise MissingConditionalImportError(extra_deps_group='peft',
+ conda_package='peft',
+ conda_channel='conda-forge')
+
+ if peft_config is not None:
+ # Hugging Face requires the peft type and task type to be upper case, so we do that here
+ # https://github.com/huggingface/peft/blob/ebbff4023ad276cbcb2466fd7e99be7d3ae0ae11/src/peft/utils/peft_types.py#L22-L51
+ if isinstance(peft_config.peft_type, str):
+ peft_config.peft_type = peft_config.peft_type.upper()
+ if isinstance(peft_config.task_type, str):
+ peft_config.task_type = peft_config.task_type.upper()
+
+ if peft_config.peft_type != 'LORA':
+ raise ValueError(
+ f'PEFT type {peft_config.peft_type} is not supported by HuggingFaceModel. Only LORA is supported.')
+
super().__init__()
self.model = model
- self.config = model.config
- self.model_forward_args = inspect.getfullargspec(self.model.forward).args
+ self.config: PretrainedConfig = model.config
+ self.model_forward_args = self._get_model_forward_args()
self.tokenizer = tokenizer
+ self.should_save_peft_only = should_save_peft_only
+ self.use_logits = use_logits
+ self.labels: Optional[torch.Tensor] = None # set in eval_forward() if exists
+ self.dummy_forward_called = False # Used to make FSDP generate work, see generate function for more details
+ self.train_metrics: Optional[Dict] = self._get_metric_dict(metrics) if metrics is not None else None
+ self.val_metrics: Optional[Dict] = self._get_metric_dict(
+ eval_metrics) if eval_metrics is not None else copy.deepcopy(self.train_metrics)
+
+ is_causal_lm = _is_registered_causal_lm(self.model)
+ self.shift_labels = is_causal_lm if shift_labels is None else shift_labels
+
+ self._check_tokenizer_and_maybe_resize_embeddings(allow_embedding_resizing)
+
+ if is_causal_lm and not self.shift_labels:
+ log.warning('The shift_labels argument was set to False but the model is an instance of a'
+ ' HuggingFace Causal LM. This may lead to incorrect behavior.')
+ # Note: No warning if shift_labels and not is_causal_lm, since the model may simply be a custom class.
+
+ if peft_config is not None:
+ self.model = _maybe_get_peft_model(peft_config, self.model)
+ self.using_peft = isinstance(self.model, PeftModel) if peft_installed else False
+
+ def _check_tokenizer_and_maybe_resize_embeddings(self, allow_embedding_resizing: bool) -> None:
if self.tokenizer is None:
log.warning(
'The tokenizer was not provided. This means the tokenizer config will not be saved in the checkpoint.')
- if tokenizer is not None and self.config.vocab_size < len(tokenizer):
+ if self.tokenizer is not None and self.config.vocab_size < len(self.tokenizer):
if allow_embedding_resizing:
# when the embedding size is smaller than the tokenizer vocab size,
# the embeddings should get resized to match the tokenizer vocab size
log.warning(f'The number of tokens in the tokenizer is greater than the number of tokens in the model.'
f' This would cause an error during training.'
- f' Resizing the model embeddings to {len(tokenizer)} from {self.config.vocab_size}.')
- self.model.resize_token_embeddings(len(tokenizer))
+ f' Resizing the model embeddings to {len(self.tokenizer)} from {self.config.vocab_size}.')
+ self.model.resize_token_embeddings(len(self.tokenizer))
else:
raise ValueError(
f'The number of tokens in the tokenizer is greater than the number of tokens in the model.'
f' This would cause an error during training.'
- f' You can resize the model embeddings to {len(tokenizer)} from {self.config.vocab_size}'
+ f' You can resize the model embeddings to {len(self.tokenizer)} from {self.config.vocab_size}'
f' by calling `model.resize_token_embeddings(len(tokenizer))` before calling the `HuggingFaceModel`'
f' constructor, or pass `allow_embedding_resizing=True` to have it done automatically.')
- elif tokenizer is not None and self.config.vocab_size > len(tokenizer):
+ elif self.tokenizer is not None and self.config.vocab_size > len(self.tokenizer):
# when the embedding size is greater than the tokenizer vocab size,
# the embeddings do not _need_ to be resized to match the tokenizer vocab size,
# and should be done by the user if desired
log.info(
f'The number of tokens in the tokenizer is less than the number of tokens in the model.'
- f' You may want to resize the model embeddings to {len(tokenizer)} from {self.config.vocab_size}'
+ f' You may want to resize the model embeddings to {len(self.tokenizer)} from {self.config.vocab_size}'
f' by calling `model.resize_token_embeddings(len(tokenizer))` before calling the `HuggingFaceModel`'
f' constructor. The vocab size is sometimes intentionally set to a multiple of 32 or 64 to improve'
f' performance.')
- self.use_logits = use_logits
+ def _get_metric_dict(self, metrics: List[Metric]) -> Dict[str, Metric]:
+ """Returns a dictionary of metrics keyed by their class name."""
+ return {metric.__class__.__name__: metric for metric in metrics}
- self.train_metrics: Optional[Dict] = None
- self.val_metrics: Optional[Dict] = None
+ def _get_model_forward_args(self) -> Set[str]:
+ """Returns the arguments to the model's forward function."""
+ model_forward_args = inspect.signature(maybe_get_underlying_model(self.model).forward).parameters.keys()
- if eval_metrics is not None:
- self.val_metrics = {metric.__class__.__name__: metric for metric in eval_metrics}
- if metrics is not None:
- self.train_metrics = {metric.__class__.__name__: metric for metric in metrics}
- # if eval_metrics is None, use the same metrics as train_metrics
- if eval_metrics is None:
- self.val_metrics = {metric.__class__.__name__: metric for metric in metrics}
+ if not model_forward_args:
+ raise ValueError('Could not determine the forward arguments of the model. Please open a GitHub issue.')
- self.labels: Optional[torch.Tensor] = None # set in eval_forward() if exists
+ model_forward_args = set(model_forward_args)
- is_causal_lm = _is_registered_causal_lm(model)
- self.shift_labels = is_causal_lm if shift_labels is None else shift_labels
- if is_causal_lm and not self.shift_labels:
- log.warning('The shift_labels argument was set to False but the model is an instance of a'
- ' HuggingFace Causal LM. This may lead to incorrect behavior.')
- # Note: No warning if shift_labels and not is_causal_lm, since the model may simply be a custom class.
+ return model_forward_args
+
+ def state_dict(self, *args, **kwargs) -> Dict[str, Any]:
+ """Returns the state dict of the model."""
+ full_state_dict = super().state_dict(*args, **kwargs)
- self.dummy_forward_called = False
+ if self.using_peft and self.should_save_peft_only:
+ active_adapter = self.model.active_adapter
+ assert isinstance(active_adapter, str)
+ full_state_dict = filter_state_dict_peft(full_state_dict,
+ self.model.peft_config[active_adapter],
+ adapter_name='default',
+ remove_adapter_names=False)
+
+ return full_state_dict
@staticmethod
def load_huggingface_tokenizer_from_saved_state(
- hf_state: Dict[str, Any],
- trust_remote_code: bool = False,
- tokenizer_save_dir: Optional[str] = None) -> Optional[transformers.PreTrainedTokenizer]:
+ hf_state: Dict[str, Any],
+ trust_remote_code: bool = False,
+ tokenizer_save_dir: Optional[str] = None
+ ) -> Optional[transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast]:
"""A helper function that loads a HuggingFace tokenizer from a loaded in hf state.
Args:
@@ -156,7 +215,7 @@ def load_huggingface_tokenizer_from_saved_state(
a folder with a unique suffix will be saved in the current working directory. Defaults to None.
Returns:
- Optional[transformers.PreTrainedTokenizer]: The loaded HuggingFace tokenizer
+ Optional[transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast]: The loaded HuggingFace tokenizer
"""
try:
import transformers
@@ -201,7 +260,7 @@ def load_huggingface_tokenizer_from_saved_state(
raise MissingConditionalImportError(extra_deps_group='sentencepiece',
conda_package='sentencepiece') from e
s = spm.SentencePieceProcessor()
- s.load_from_serialized_proto(saved_content['content'])
+ s.load_from_serialized_proto(saved_content['content']) # pyright: ignore[reportGeneralTypeIssues]
with open(tokenizer_file_path, 'wb') as _f:
_f.write(s.serialized_model_proto())
@@ -265,7 +324,8 @@ def load_huggingface_model_from_saved_state(
# pyright can't tell this isn't a string at this point
if issubclass(
model_instantiation_class, # type: ignore
- transformers.models.auto.auto_factory._BaseAutoModelClass):
+ transformers.models.auto.auto_factory._BaseAutoModelClass # type: ignore
+ ): # pyright: ignore[reportGeneralTypeIssues]
hf_model = model_instantiation_class.from_config(loaded_config) # type: ignore
else:
hf_model = model_instantiation_class(loaded_config) # type: ignore
@@ -291,7 +351,8 @@ def hf_from_composer_checkpoint(
model_config_kwargs: Optional[dict] = None,
local_checkpoint_save_location: Optional[Union[Path, str]] = None,
trust_remote_code: bool = False,
- ) -> Tuple[transformers.PreTrainedModel, Optional[transformers.PreTrainedTokenizer]]:
+ ) -> Tuple[transformers.PreTrainedModel, Optional[Union[transformers.PreTrainedTokenizer,
+ transformers.PreTrainedTokenizerFast]]]:
"""Loads a HuggingFace model (and tokenizer if present) from a composer checkpoint.
.. note:: This function does not load the weights from the checkpoint. It just loads the correctly configured
@@ -353,7 +414,7 @@ def hf_from_composer_checkpoint(
ValueError: If the ``model_instantiation_class``, or the model class saved in the checkpoint, is not able to be imported
Returns:
- Tuple[transformers.PreTrainedModel, Optional[transformers.PreTrainedTokenizer]]: The loaded HuggingFace model and (if present) tokenizer
+ Tuple[transformers.PreTrainedModel, Optional[Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast]]]: The loaded HuggingFace model and (if present) tokenizer
"""
# default local path to a tempfile if path is not provided
@@ -413,7 +474,8 @@ def eval_forward(self, batch, outputs: Optional[Any] = None):
**batch.get('generation_kwargs', {}))
# don't remove prefix space to sentencepiece models
- if len(self.tokenizer(' a', add_special_tokens=False)['input_ids']) == 1:
+ if len(self.tokenizer(
+ ' a', add_special_tokens=False)['input_ids']) == 1: # pyright: ignore[reportGeneralTypeIssues]
return self.tokenizer.batch_decode(generation[:, batch['input_ids'].shape[1]:],
skip_special_tokens=True)
else:
@@ -429,7 +491,7 @@ def eval_forward(self, batch, outputs: Optional[Any] = None):
# HF encoder decoder models like T5 expect either decoder_input_ids or labels,
# so we add decoder_input_ids to the batch if it is missing
- if self.model.config.is_encoder_decoder and 'decoder_input_ids' not in batch:
+ if self.config.is_encoder_decoder and 'decoder_input_ids' not in batch:
if hasattr(self.model, 'prepare_decoder_input_ids_from_labels'):
batch['decoder_input_ids'] = self.model.prepare_decoder_input_ids_from_labels(labels=self.labels)
else:
@@ -469,14 +531,10 @@ def get_metrics(self, is_train: bool = False) -> Dict[str, Metric]:
return metrics if metrics else {}
def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None:
- if isinstance(metric, InContextLearningQAAccuracy):
- assert self.labels is not None
- metric.update(batch=batch, outputs=outputs, labels=self.labels) # pyright: ignore [reportGeneralTypeIssues]
- elif isinstance(metric, InContextLearningMetric):
- assert self.labels is not None
- metric.update(batch, outputs, self.labels) # pyright: ignore [reportGeneralTypeIssues]
+ if getattr(metric, 'needs_batch', False):
+ metric.update(batch=batch, outputs=outputs, labels=self.labels)
else:
- metric.update(outputs, self.labels) # pyright: ignore [reportGeneralTypeIssues]
+ metric.update(outputs, self.labels)
def get_metadata(self):
model_output = {}
@@ -485,7 +543,9 @@ def get_metadata(self):
tmp_dir = Path(tmp_dir)
model_dir = tmp_dir / 'model'
tokenizer_dir = tmp_dir / 'tokenizer'
- self.model.config.save_pretrained(model_dir)
+
+ original_model_config: PretrainedConfig = self.config
+ original_model_config.save_pretrained(model_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(tokenizer_dir)
@@ -498,6 +558,19 @@ def get_metadata(self):
'class': f'{self.model.__class__.__module__}.{self.model.__class__.__name__}'
}
+ # Also save PEFT config if the model is a peft model
+ if self.using_peft:
+ active_adapter = self.model.active_adapter
+ assert isinstance(active_adapter, str)
+ self.model.peft_config[active_adapter].save_pretrained(str(model_dir))
+ with open(model_dir / 'adapter_config.json') as _peft_config_file:
+ peft_config = json.load(_peft_config_file)
+
+ model_output['peft_config'] = {
+ 'file_extension': '.json',
+ 'content': peft_config,
+ }
+
if self.tokenizer is not None:
for tokenizer_file_name in tokenizer_dir.iterdir():
tokenizer_file_path = tokenizer_dir / tokenizer_file_name
@@ -517,7 +590,8 @@ def get_metadata(self):
except ImportError as e:
raise MissingConditionalImportError(extra_deps_group='sentencepiece',
conda_package='sentencepiece') from e
- s = spm.SentencePieceProcessor(model_file=str(tokenizer_file_path))
+ s = spm.SentencePieceProcessor(
+ model_file=str(tokenizer_file_path)) # pyright: ignore[reportGeneralTypeIssues]
tokenizer_file_content = s.serialized_model_proto()
else:
raise ValueError(
@@ -542,25 +616,7 @@ def generate(self, input_ids: torch.Tensor, **kwargs):
"""
pad_token_id = kwargs.pop('pad_token_id', self.tokenizer.pad_token_id if self.tokenizer is not None else None)
- from composer.utils.misc import using_torch_2
-
- # We need to call forward once in order for FSDP + generate to work
- # This solution works because parameters in the root FSDP module are not freed after forward
- # See https://github.com/huggingface/accelerate/issues/570, https://github.com/huggingface/accelerate/issues/947,
- # and https://github.com/pytorch/pytorch/issues/82461, https://github.com/pytorch/pytorch/issues/100069 for more info
- # Note: This is a solution for Torch 1.13.x, and there is a different solution below for Torch 2.0
- if not using_torch_2() and not self.dummy_forward_called and is_model_fsdp(self.model):
- with torch.no_grad():
- maybe_decoder_input_ids = {}
- if self.model.config.is_encoder_decoder:
- maybe_decoder_input_ids['decoder_input_ids'] = torch.tensor([[0]],
- dtype=torch.long,
- device=input_ids.device)
- self.model(input_ids=torch.tensor([[0]], dtype=torch.long, device=input_ids.device),
- **maybe_decoder_input_ids)
- self.dummy_forward_called = True
-
- if is_model_fsdp(self.model) and using_torch_2():
+ if is_model_fsdp(self.model):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
# Note: We need to use the FSDP.summon_full_params context manager here because the generate function
@@ -574,7 +630,49 @@ def generate(self, input_ids: torch.Tensor, **kwargs):
return self.model.generate(input_ids=input_ids, pad_token_id=pad_token_id, **kwargs)
-def _is_registered_causal_lm(model: transformers.PreTrainedModel) -> bool:
+def _maybe_get_peft_model(
+ peft_config: 'PeftConfig',
+ model: Union[transformers.PreTrainedModel, 'PeftModel'],
+) -> 'PeftModel':
+ """Creates a PEFT model if the model is not already a PEFT model.
+
+ Args:
+ peft_config (Optional[peft.PeftConfig]): The PEFT config to use to create the PEFT model
+ model (Union[transformers.PreTrainedModel, 'PeftModel']): The model to create the PEFT model from
+
+ Returns:
+ PeftModel: The PEFT model
+ """
+ if not peft_installed:
+ raise MissingConditionalImportError(extra_deps_group='peft', conda_package='peft', conda_channel='conda-forge')
+
+ if not isinstance(model, PeftModel):
+ log.info('Creating PEFT model')
+ peft_model = get_peft_model(model, peft_config)
+ assert isinstance(peft_model, PeftModel)
+ return peft_model
+ else:
+ warnings.warn('PEFT model was passed in directly. Ignoring the provided PEFT config.')
+ return model
+
+
+def maybe_get_underlying_model(
+ model: Union[transformers.PreTrainedModel, 'PeftModel']) -> Union[transformers.PreTrainedModel, 'PeftModel']:
+ """Get the underlying PreTrainedModel from a model if it is a PEFT model
+
+ Args:
+ model (Union[transformers.PreTrainedModel, 'PeftModel']): The model to get the underlying model from
+
+ Returns:
+ Union[transformers.PreTrainedModel]: The underlying transformers model
+ """
+ if peft_installed and isinstance(model, PeftModel):
+ return model.base_model.model
+ else:
+ return model
+
+
+def _is_registered_causal_lm(model: Union[transformers.PreTrainedModel, 'PeftModel']) -> bool:
"""Return True if model class is either a registered đ¤ Causal LM or a subclass of one"""
try:
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING
@@ -583,6 +681,8 @@ def _is_registered_causal_lm(model: transformers.PreTrainedModel) -> bool:
conda_package='transformers',
conda_channel='conda-forge') from e
+ model_to_check = maybe_get_underlying_model(model)
+
# This try/except is needed until https://github.com/huggingface/transformers/issues/26778
# is resolved in a release. This means that this attempt to automatically detect causal LMs
# does not currently work in an environment with flash attention <2 installed.
@@ -594,7 +694,7 @@ def _is_registered_causal_lm(model: transformers.PreTrainedModel) -> bool:
return False
else:
raise e
- return any(isinstance(model, causal_lm_class) for causal_lm_class in causal_lm_classes)
+ return any(isinstance(model_to_check, causal_lm_class) for causal_lm_class in causal_lm_classes) # type: ignore
def get_hf_config_from_composer_state_dict(state_dict: Dict[str, Any],
@@ -637,6 +737,30 @@ def get_hf_config_from_composer_state_dict(state_dict: Dict[str, Any],
f'config has a valid `_name_or_path`.')
+def get_peft_config_from_composer_state_dict(state_dict: Dict[str, Any]) -> Optional['PeftConfig']:
+ """Get a PEFT config from a composer state dict
+
+ Args:
+ state_dict (Dict[str, Any]): The state dict to get the config from
+
+ Returns:
+ Optional[peft.PeftConfig]: The PEFT config. Will be ``None`` if the model is not a PEFT model.
+ """
+ try:
+ import peft
+ except ImportError as e:
+ raise MissingConditionalImportError(extra_deps_group='nlp', conda_package='peft',
+ conda_channel='conda-forge') from e
+
+ hf_model_dict = state_dict['state']['integrations']['huggingface']['model']
+ if 'peft_config' not in hf_model_dict:
+ return None
+
+ peft_config_dict = hf_model_dict['peft_config']['content']
+
+ return peft.get_peft_config(peft_config_dict)
+
+
def write_huggingface_pretrained_from_composer_checkpoint(
checkpoint_path: Union[Path, str],
output_folder: Union[Path, str],
@@ -713,6 +837,61 @@ def write_huggingface_pretrained_from_composer_checkpoint(
config = get_hf_config_from_composer_state_dict(composer_state_dict)
config.save_pretrained(output_folder)
+ peft_config = get_peft_config_from_composer_state_dict(composer_state_dict)
+ if peft_config is not None:
+ peft_config.save_pretrained(str(output_folder))
+
weights_state_dict = composer_state_dict['state']['model']
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(weights_state_dict, prefix='model.')
- torch.save(weights_state_dict, Path(output_folder) / 'pytorch_model.bin')
+
+ # NOTE: This only works for default adapter name, not multiple adapters
+ if peft_config is not None:
+ weights_state_dict = filter_state_dict_peft(weights_state_dict, peft_config, adapter_name='default')
+
+ torch.save(weights_state_dict, Path(output_folder) / 'adapter_model.bin')
+ else:
+ torch.save(weights_state_dict, Path(output_folder) / 'pytorch_model.bin')
+
+
+def filter_state_dict_peft(state_dict: Dict[str, Any],
+ peft_config: 'PeftConfig',
+ adapter_name: str = 'default',
+ remove_adapter_names: bool = True) -> Dict[str, Any]:
+ """Filter a state dict to only include the weights needed for a PEFT model
+
+ Note: This function only works with LORA PEFT models right now.
+
+ Args:
+ state_dict (Dict[str, Any]): The state dict to filter
+ peft_config (PeftConfig): The PEFT config to use to filter the state dict
+ adapter_name (str, optional): The name of the adapter to filter for. Defaults to 'default'.
+ remove_adapter_names (bool, optional): Whether to remove the adapter names from the state dict keys. Defaults to True.
+
+ Returns:
+ Dict[str, Any]: The filtered state dict
+ """
+
+ if peft_config.peft_type != 'LORA':
+ raise NotImplementedError(f'Only LoRA PEFT is supported. Got {peft_config.peft_type}')
+
+ # Filtering copied from https://github.com/huggingface/peft/blob/4186c9b104644fd247a4cc0dc2dfc1ede4665204/src/peft/utils/save_and_load.py#L68C1-L86C116
+ bias = peft_config.bias # type: ignore
+ if bias == 'none':
+ to_return = {k: state_dict[k] for k in state_dict if 'lora_' in k}
+ elif bias == 'all':
+ to_return = {k: state_dict[k] for k in state_dict if 'lora_' in k or 'bias' in k}
+ elif bias == 'lora_only':
+ to_return = {}
+ for k in state_dict:
+ if 'lora_' in k:
+ to_return[k] = state_dict[k]
+ bias_name = k.split('lora_')[0] + 'bias'
+ if bias_name in state_dict:
+ to_return[bias_name] = state_dict[bias_name]
+ else:
+ raise NotImplementedError
+ to_return = {k: v for k, v in to_return.items() if (('lora_' in k and adapter_name in k) or ('bias' in k))}
+
+ if remove_adapter_names:
+ to_return = {k.replace(f'.{adapter_name}', ''): v for k, v in to_return.items()}
+ return to_return
diff --git a/composer/models/mmdetection.py b/composer/models/mmdetection.py
deleted file mode 100644
index 2e53aac543..0000000000
--- a/composer/models/mmdetection.py
+++ /dev/null
@@ -1,126 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""A wrapper class that converts mmdet detection models to composer models"""
-
-from __future__ import annotations
-
-import warnings
-from typing import TYPE_CHECKING, Any, List, Optional
-
-import numpy as np
-import torch
-from torchmetrics import Metric
-from torchmetrics.collections import MetricCollection
-
-from composer.models import ComposerModel
-
-if TYPE_CHECKING:
- import mmdet
-
-__all__ = ['MMDetModel']
-
-
-class MMDetModel(ComposerModel):
- """A wrapper class that adapts mmdetection detectors to composer models.
-
- Args:
- model (mmdet.models.detectors.BaseDetector): An MMdetection Detector.
- metrics (list[Metric], optional): list of torchmetrics to apply to the output of `eval_forward`. Default: ``None``.
-
- .. warning:: This wrapper is designed to work with mmdet datasets.
-
- Example:
-
- .. code-block:: python
-
- from mmdet.models import build_model
- from mmcv import ConfigDict
- from composer.models import MMDetModel
-
- yolox_s_config = dict(
- type='YOLOX',
- input_size=(640, 640),
- random_size_range=(15, 25),
- random_size_interval=10,
- backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
- neck=dict(type='YOLOXPAFPN', in_channels=[128, 256, 512], out_channels=128, num_csp_blocks=1),
- bbox_head=dict(type='YOLOXHead', num_classes=num_classes, in_channels=128, feat_channels=128),
- train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
- test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))
- yolox = build_model(ConfigDict(yolox_s_config))
- yolox.init_weights()
- model = MMDetModel(yolox)
- """
-
- def __init__(
- self,
- model: mmdet.models.detectors.BaseDetector, # type: ignore
- metrics: Optional[List[Metric]] = None) -> None:
- warnings.warn(DeprecationWarning('MMDetModel is deprecated and will be removed in v0.18'))
- super().__init__()
- self.model = model
-
- self.train_metrics = None
- self.val_metrics = None
-
- if metrics:
- metric_collection = MetricCollection(metrics)
- self.train_metrics = metric_collection.clone(prefix='train_')
- self.val_metrics = metric_collection.clone(prefix='val_')
-
- def forward(self, batch):
- # this will return a dictionary of losses in train mode and model outputs in test mode.
- return self.model(**batch)
-
- def loss(self, outputs, batch, **kwargs):
- return outputs
-
- def eval_forward(self, batch, outputs: Optional[Any] = None):
- """
- Args:
- batch (dict): a eval batch of the format:
-
-
- ``img`` (List[torch.Tensor]): list of image torch.Tensors of shape (batch, c, h , w).
-
-
- ``img_metas`` (List[Dict]): (1, batch_size) list of ``image_meta`` dicts.
- Returns: model predictions: A batch_size length list of dictionaries containg detection boxes in (x,y, x2, y2) format, class labels, and class probabilities.
- """
- device = batch['img'][0].device
- batch.pop('gt_labels')
- batch.pop('gt_bboxes')
- results = self.model(return_loss=False, rescale=True, **batch) # models behave differently in eval mode
-
- # outputs are a list of bbox results (x, y, x2, y2, score)
- # pack mmdet bounding boxes and labels into the format for torchmetrics MAP expects
- preds = []
- for bbox_result in results:
- boxes_scores = np.vstack(bbox_result)
- boxes, scores = torch.from_numpy(boxes_scores[..., :-1]).to(device), torch.from_numpy(
- boxes_scores[..., -1]).to(device)
- labels = [np.full(result.shape[0], i, dtype=np.int32) for i, result in enumerate(bbox_result)]
- pred = {
- 'labels': torch.from_numpy(np.concatenate(labels)).to(device).long(),
- 'boxes': boxes.float(),
- 'scores': scores.float()
- }
- preds.append(pred)
- return preds
-
- def get_metrics(self, is_train: bool = False):
- if is_train:
- metrics = self.train_metrics
- else:
- metrics = self.val_metrics
- return metrics if metrics else {}
-
- def update_metric(self, batch: Any, outputs: Any, metric: Metric):
- targets_box = batch.pop('gt_bboxes')[0]
- targets_cls = batch.pop('gt_labels')[0]
- targets = []
- for i in range(len(targets_box)):
- t = {'boxes': targets_box[i], 'labels': targets_cls[i]}
- targets.append(t)
- metric.update(outputs, targets)
diff --git a/composer/models/resnet/README.md b/composer/models/resnet/README.md
deleted file mode 100644
index 430dd303b4..0000000000
--- a/composer/models/resnet/README.md
+++ /dev/null
@@ -1,69 +0,0 @@
-# đī¸ ResNet
-[\[How to Use\]](#how-to-use) · [\[Architecture\]](#architecture) · [\[Family Members\]](#family-members) · [\[Default Training Hyperparameters\]](#default-training-hyperparameters) · [\[Attribution\]](#attribution) · [\[API Reference\]](#api-reference)
-
-`Vision` / `Image Classification`
-
-The ResNet model family is a set of convolutional neural networks that can be used as a basis for a variety of vision tasks. Our implementation is a simple wrapper on top of the [torchvision ResNet implementation](https://pytorch.org/vision/stable/models.html).
-
-## How to Use
-
-```python
-from composer.models import composer_resnet
-
-model = composer_resnet(
- model_name="resnet50",
- num_classes=1000,
- weights=None
-)
-```
-
-## Architecture
-
-The basic architecture defined in the original papers is as follows:
-
-- The first layer is a 7x7 Convolution with stride 2 and 64 filters.
-- Subsequent layers follow 4 stages with {64, 128, 256, 512} input channels with a varying number of residual blocks at each stage that depends on the family member. At the end of every stage, the resolution is reduced by half using a convolution with stride 2.
-- The final section consists of a global average pooling followed by a linear + softmax layer that outputs values for the specified number of classes.
-
-The below table from [He et al.](https://arxiv.org/abs/1512.03385) details some of the building blocks for ResNets of different sizes.
-
-![resnet.png](https://storage.googleapis.com/docs.mosaicml.com/images/models/resnet.png)
-
-## Family Members
-
-ResNet family members are identified by their number of layers. Parameter count, accuracy, and training time are provided below.
-
-| Model Family Members | Parameter Count | Our Accuracy | Training Time on 8xA100s |
-|----------------------|-----------------|--------------|--------------------------|
-| ResNet-18 | 11.5M | TBA | TBA |
-| ResNet-34 | 21.8M | TBA | TBA |
-| ResNet-50 | 25.6M | 76.5% | 3.83 hrs |
-| ResNet-101 | 44.5M | 78.1% | 5.50 hrs |
-| ResNet-152 | 60.2M | TBA | TBA |
-
-
-> â **Note**: Please see the [CIFAR ResNet model card](https://docs.mosaicml.com/projects/composer/en/stable/model_cards/cifar_resnet.html#architecture) for the differences between CIFAR and ImageNet ResNets.
-
-## Default Training Hyperparameters
-
-- Optimizer: Decoupled SGDW
- - Learning rate: 2.048
- Momentum: 0.875
- Weight_decay: 5.0e-4
-- LR schedulers:
- - Cosine decay with warmup for 8 epochs
-- Batch size: 2048
-- Number of epochs: 90ep
-
-## Attribution
-
-Paper: [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
-
-Code and hyperparameters: [DeepLearningExamples Github repository](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/resnet50v1.5) by Nvidia
-
-## API Reference
-
-```{eval-rst}
-.. autofunction:: composer.models.resnet.model.composer_resnet
- :noindex:
-```
diff --git a/composer/models/resnet/__init__.py b/composer/models/resnet/__init__.py
deleted file mode 100644
index e00a37035b..0000000000
--- a/composer/models/resnet/__init__.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""The ResNet model family is a set of convolutional neural networks described in `Deep Residual Learning for Image
-Recognition `_ (He et al, 2015). ResNets can be used as the base for a variety of
-vision tasks. ImageNet ResNets are a subset of the ResNet family which were designed specifically for classification on
-the ImageNet dataset.
-
-See the :doc:`Model Card ` for more details.
-"""
-from composer.models.resnet.model import composer_resnet
-
-__all__ = ['composer_resnet']
-
-_metadata = {
- 'resnet18': {
- '_task': 'Image Classification',
- '_dataset': 'ImageNet',
- '_name': 'ResNet18',
- '_quality': 'TBD',
- '_metric': 'Top-1 Accuracy',
- '_ttt': 'TBD',
- '_hparams': 'resnet18.yaml'
- },
- 'resnet34': {
- '_task': 'Image Classification',
- '_dataset': 'ImageNet',
- '_name': 'ResNet34',
- '_quality': 'TBD',
- '_metric': 'Top-1 Accuracy',
- '_ttt': 'TBD',
- '_hparams': 'resnet34.yaml'
- },
- 'resnet50': {
- '_task': 'Image Classification',
- '_dataset': 'ImageNet',
- '_name': 'ResNet50',
- '_quality': '76.51',
- '_metric': 'Top-1 Accuracy',
- '_ttt': '3h 33m',
- '_hparams': 'resnet50.yaml'
- },
- 'resnet101': {
- '_task': 'Image Classification',
- '_dataset': 'ImageNet',
- '_name': 'ResNet101',
- '_quality': '78.10',
- '_metric': 'Top-1 Accuracy',
- '_ttt': '8h 15m',
- '_hparams': 'resnet101.yaml',
- },
- 'resnet152': {
- '_task': 'Image Classification',
- '_dataset': 'ImageNet',
- '_name': 'ResNet152',
- '_quality': 'TBD',
- '_metric': 'Top-1 Accuracy',
- '_ttt': 'TBD',
- '_hparams': 'resnet152.yaml'
- }
-}
diff --git a/composer/models/resnet/model.py b/composer/models/resnet/model.py
deleted file mode 100644
index 5b023fabcf..0000000000
--- a/composer/models/resnet/model.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""A :class:`.ComposerClassifier` wrapper around the torchvision implementations of the ResNet model family."""
-
-import logging
-import warnings
-from typing import List, Optional
-
-from torchmetrics import MetricCollection
-from torchmetrics.classification import MulticlassAccuracy
-from torchvision.models import resnet
-
-from composer.loss import loss_registry
-from composer.metrics import CrossEntropy
-from composer.models.initializers import Initializer
-from composer.models.tasks import ComposerClassifier
-
-__all__ = ['composer_resnet']
-
-log = logging.getLogger(__name__)
-
-valid_model_names = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
-
-
-def composer_resnet(model_name: str,
- num_classes: int = 1000,
- weights: Optional[str] = None,
- groups: int = 1,
- width_per_group: int = 64,
- initializers: Optional[List[Initializer]] = None,
- loss_name: str = 'soft_cross_entropy') -> ComposerClassifier:
- """Helper function to create a :class:`.ComposerClassifier` with a torchvision ResNet model.
-
- From `Deep Residual Learning for Image Recognition `_ (He et al, 2015).
-
- Args:
- model_name (str): Name of the ResNet model instance. Either [``"resnet18"``, ``"resnet34"``, ``"resnet50"``, ``"resnet101"``,
- ``"resnet152"``].
- num_classes (int, optional): The number of classes. Needed for classification tasks. Default: ``1000``.
- weights (str, optional): If provided, pretrained weights can be specified, such as with ``IMAGENET1K_V2``. Default: ``None``.
- groups (int, optional): Number of filter groups for the 3x3 convolution layer in bottleneck blocks. Default: ``1``.
- width_per_group (int, optional): Initial width for each convolution group. Width doubles after each stage.
- Default: ``64``.
- initializers (List[Initializer], optional): Initializers for the model. ``None`` for no initialization.
- Default: ``None``.
- loss_name (str, optional): Loss function to use. E.g. 'soft_cross_entropy' or
- 'binary_cross_entropy_with_logits'. Loss function must be in
- :mod:`~composer.loss.loss`. Default: ``'soft_cross_entropy'``".
- Returns:
- ComposerModel: instance of :class:`.ComposerClassifier` with a torchvision ResNet model.
-
- Example:
-
- .. testcode::
-
- from composer.models import composer_resnet
-
- model = composer_resnet(model_name='resnet18') # creates a torchvision resnet18 for image classification
- """
- warnings.warn(DeprecationWarning('composer_resnet is deprecated and will be removed in v0.18'))
-
- valid_model_names = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
- if model_name not in valid_model_names:
- raise ValueError(f'model_name must be one of {valid_model_names} instead of {model_name}.')
-
- if loss_name not in loss_registry.keys():
- raise ValueError(f'Unrecognized loss function: {loss_name}. Please ensure the '
- 'specified loss function is present in composer.loss.loss.py')
-
- if loss_name == 'binary_cross_entropy_with_logits' and (initializers is None or
- Initializer.LINEAR_LOG_CONSTANT_BIAS not in initializers):
- log.warning('UserWarning: Using `binary_cross_entropy_loss_with_logits` '
- 'without using `initializers.linear_log_constant_bias` can degrade '
- 'performance. '
- 'Please ensure you are using `initializers. '
- 'linear_log_constant_bias`.')
-
- if initializers is None:
- initializers = []
-
- # Instantiate model
- model_fn = getattr(resnet, model_name)
- model = model_fn(weights=weights, num_classes=num_classes, groups=groups, width_per_group=width_per_group)
-
- # Grab loss function from loss registry
- loss_fn = loss_registry[loss_name]
-
- # Create metrics for train and validation
- train_metrics = MulticlassAccuracy(num_classes=num_classes, average='micro')
- val_metrics = MetricCollection([CrossEntropy(), MulticlassAccuracy(num_classes=num_classes, average='micro')])
-
- # Apply Initializers to model
- for initializer in initializers:
- initializer = Initializer(initializer)
- model.apply(initializer.get_initializer())
-
- composer_model = ComposerClassifier(model, train_metrics=train_metrics, val_metrics=val_metrics, loss_fn=loss_fn)
- return composer_model
diff --git a/composer/models/resnet_cifar/README.md b/composer/models/resnet_cifar/README.md
deleted file mode 100644
index 5a32ae03b8..0000000000
--- a/composer/models/resnet_cifar/README.md
+++ /dev/null
@@ -1,73 +0,0 @@
-# CIFAR ResNet
-[\[Example\]](#example) · [\[Architecture\]](#architecture) · [\[Family Members\]](#family-members) · [\[Default Training Hyperparameters\]](#default-training-hyperparameters) · [\[Attribution\]](#attribution) · [\[API Reference\]](#api-reference)
-
-`Vision` / `Image Classification`
-
-The ResNet model family is a set of convolutional neural networks that can be used as the basis for a variety of vision tasks. CIFAR ResNet models are a subset of this family designed specifically for the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) and [CIFAR-100](https://www.cs.toronto.edu/~kriz/cifar.html) datasets.
-
-## Example
-
-```python
-from composer.models import composer_resnet_cifar
-
-model = composer_resnet_cifar(model_name='resnet_56', num_classes=10)
-```
-
-## Architecture
-
-Residual Networks are feedforward convolutional networks with âresidualâ connections between non-consecutive layers.
-
-The model architecture is defined by the original paper:
-
-- The network inputs are of dimension 32Ã32x3.
-- The first layer is 3Ã3 convolutions
-- The subsequent layers are a stack of 6n layers with 3Ã3 convolutions on the feature maps of sizes {32,16,8}, with 2n layers for each feature map size. The number of filters are {16,32,64} for the respective feature map sizes. Subsampling is performed by convolutions with a stride of 2
-- The network ends with a global average pooling, a linear layer with the output dimension equal to the number of classes, and softmax function.
-
-There are a total 6n+2 stacked weighted layers. Each family member is specified by the number of layers, for example n=9 corresponds to ResNet56
-
-The biggest differences between CIFAR ResNet models and ImageNet ResNet models are:
-
-- CIFAR ResNet models use fewer filters for each convolution.
-- The ImageNet ResNets contain four stages, while the CIFAR ResNets contain three stages. In addition, CIFAR ResNets uniformly distribute blocks across each stage while ImageNet ResNets have a specific number of blocks for each stage.
-
-## Family Members
-
-| Model Family Members | Parameter Count | Our Accuracy | Training Time on 1x3080 |
-|----------------------|-----------------|--------------|-------------------------|
-| ResNet20 | 0.27M | TBA | TBA |
-| ResNet32 | 0.46M | TBA | TBA |
-| ResNet44 | 0.66M | TBA | TBA |
-| ResNet56 | 0.85M | 93.1% | 35 min |
-| ResNet110 | 1.7M | TBA | TBA |
-## Default Training Hyperparameters
-
-```yaml
-optimizer:
- sgd:
- learning_rate: 1.2
- momentum: 0.9
- weight_decay: 1e-4
-schedulers:
- - multistep_with_warmup:
- t_warmup: "5ep"
- milestones:
- - "80ep"
- - "120ep"
- gamma: 0.1
-train_batch_size: 1024
-max_duration: 160ep
-```
-
-## Attribution
-
-Paper: [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
-
-Note that this paper set the standard for ResNet style architectures for both CIFAR-10/100 and ImageNet
-
-## API Reference
-
-```{eval-rst}
-.. autoclass:: composer.models.resnet_cifar.model.composer_resnet_cifar
- :noindex:
-```
diff --git a/composer/models/resnet_cifar/__init__.py b/composer/models/resnet_cifar/__init__.py
deleted file mode 100644
index 2ea6ac226c..0000000000
--- a/composer/models/resnet_cifar/__init__.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""A ResNet model family adapted for CIFAR10 image sizes.
-
-See the :doc:`Model Card ` for more details.
-"""
-
-from composer.models.resnet_cifar.model import composer_resnet_cifar as composer_resnet_cifar
-
-__all__ = ['composer_resnet_cifar']
-_metadata = {
- 'resnet9': {
- '_task': 'Image Classification',
- '_dataset': 'CIFAR10',
- '_name': 'ResNet9',
- '_quality': 'tbd',
- '_metric': 'Top-1 Accuracy',
- '_ttt': 'tbd',
- '_hparams': 'resnet9_cifar10.yaml'
- },
- 'resnet20': {
- '_task': 'Image Classification',
- '_dataset': 'CIFAR10',
- '_name': 'ResNet20',
- '_quality': 'tbd',
- '_metric': 'Top-1 Accuracy',
- '_ttt': 'tbd',
- '_hparams': 'resnet20_cifar10.yaml'
- },
- 'resnet56': {
- '_task': 'Image Classification',
- '_dataset': 'CIFAR10',
- '_name': 'ResNet56',
- '_quality': '93.1',
- '_metric': 'Top-1 Accuracy',
- '_ttt': '35m',
- '_hparams': 'resnet56_cifar10.yaml'
- }
-}
diff --git a/composer/models/resnet_cifar/model.py b/composer/models/resnet_cifar/model.py
deleted file mode 100644
index 5bb8660b56..0000000000
--- a/composer/models/resnet_cifar/model.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""ResNet models for CIFAR extending :class:`.ComposerClassifier`."""
-
-import warnings
-from typing import List, Optional
-
-from composer.models.initializers import Initializer
-from composer.models.resnet_cifar.resnets import ResNet9, ResNetCIFAR
-from composer.models.tasks import ComposerClassifier
-
-__all__ = ['composer_resnet_cifar']
-
-
-def composer_resnet_cifar(model_name: str,
- num_classes: int = 10,
- initializers: Optional[List[Initializer]] = None) -> ComposerClassifier:
- """Helper function to create a :class:`.ComposerClassifier` with a CIFAR ResNet models.
-
- From `Deep Residual Learning for Image Recognition `_ (He et al, 2015).
- ResNet9 is based on the model from myrtle.ai `blog`_.
-
- Args:
- model_name (str): ``"resnet_9"``, ``"resnet_20"``, or ``"resnet_56"``.
- num_classes (int, optional): The number of classes. Needed for classification tasks. Default: ``10``.
- initializers (List[Initializer], optional): Initializers for the model. ``None`` for no initialization.
- Default: ``None``.
-
- Returns:
- ComposerModel: instance of :class:`.ComposerClassifier` with a CIFAR ResNet model.
-
- Example:
-
- .. testcode::
-
- from composer.models import composer_resnet_cifar
-
- model = composer_resnet_cifar(model_name="resnet_56") # creates a resnet56 for cifar image classification
-
- .. _blog: https://myrtle.ai/learn/how-to-train-your-resnet-4-architecture/
- """
- warnings.warn(DeprecationWarning('composer_resnet_cifar is deprecated and will be removed in v0.18'))
- if initializers is None:
- initializers = []
-
- if model_name == 'resnet_9':
- model = ResNet9(num_classes) # current initializers don't work with this architecture.
- else:
- model = ResNetCIFAR.get_model_from_name(model_name, initializers, num_classes)
-
- composer_model = ComposerClassifier(module=model, num_classes=num_classes)
- return composer_model
diff --git a/composer/models/resnet_cifar/resnets.py b/composer/models/resnet_cifar/resnets.py
deleted file mode 100644
index b4f1576b46..0000000000
--- a/composer/models/resnet_cifar/resnets.py
+++ /dev/null
@@ -1,170 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""The CIFAR ResNet torch module.
-
-See the :doc:`Model Card ` for more details.
-"""
-
-# Code below adapted from https://github.com/facebookresearch/open_lth
-# and https://github.com/pytorch/vision
-
-from typing import List, Tuple
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torchvision.models.resnet import BasicBlock
-
-from composer.models import Initializer
-
-__all__ = ['ResNetCIFAR', 'ResNet9']
-
-
-class ResNetCIFAR(nn.Module):
- """A residual neural network as originally designed for CIFAR-10."""
-
- class Block(nn.Module):
- """A ResNet block."""
-
- def __init__(self, f_in: int, f_out: int, downsample: bool = False):
- super(ResNetCIFAR.Block, self).__init__()
-
- stride = 2 if downsample else 1
- self.conv1 = nn.Conv2d(f_in, f_out, kernel_size=3, stride=stride, padding=1, bias=False)
- self.bn1 = nn.BatchNorm2d(f_out)
- self.conv2 = nn.Conv2d(f_out, f_out, kernel_size=3, stride=1, padding=1, bias=False)
- self.bn2 = nn.BatchNorm2d(f_out)
- self.relu = nn.ReLU(inplace=True)
-
- # No parameters for shortcut connections.
- if downsample or f_in != f_out:
- self.shortcut = nn.Sequential(
- nn.Conv2d(f_in, f_out, kernel_size=1, stride=2, bias=False),
- nn.BatchNorm2d(f_out),
- )
- else:
- self.shortcut = nn.Sequential()
-
- def forward(self, x: torch.Tensor):
- out = self.relu(self.bn1(self.conv1(x)))
- out = self.bn2(self.conv2(out))
- out += self.shortcut(x)
- return self.relu(out)
-
- def __init__(self, plan: List[Tuple[int, int]], initializers: List[Initializer], outputs: int = 10):
- super(ResNetCIFAR, self).__init__()
- outputs = outputs or 10
-
- self.num_classes = outputs
-
- # Initial convolution.
- current_filters = plan[0][0]
- self.conv = nn.Conv2d(3, current_filters, kernel_size=3, stride=1, padding=1, bias=False)
- self.bn = nn.BatchNorm2d(current_filters)
- self.relu = nn.ReLU(inplace=True)
-
- # The subsequent blocks of the ResNet.
- blocks = []
- for segment_index, (filters, num_blocks) in enumerate(plan):
- for block_index in range(num_blocks):
- downsample = segment_index > 0 and block_index == 0
- blocks.append(ResNetCIFAR.Block(current_filters, filters, downsample))
- current_filters = filters
-
- self.blocks = nn.Sequential(*blocks)
-
- # Final fc layer. Size = number of filters in last segment.
- self.fc = nn.Linear(plan[-1][0], outputs)
- self.criterion = nn.CrossEntropyLoss()
-
- for initializer in initializers:
- initializer = Initializer(initializer)
- self.apply(initializer.get_initializer())
-
- def forward(self, x: torch.Tensor):
- out = self.relu(self.bn(self.conv(x)))
- out = self.blocks(out)
- out = F.avg_pool2d(out, out.size()[3])
- out = out.view(out.size(0), -1)
- out = self.fc(out)
- return out
-
- @staticmethod
- def is_valid_model_name(model_name: str):
- valid_model_names = [f'resnet_{layers}' for layers in (20, 56)]
- return (model_name in valid_model_names)
-
- @staticmethod
- def get_model_from_name(model_name: str, initializers: List[Initializer], outputs: int = 10):
- """The naming scheme for a ResNet is ``'resnet_D[_W]'``.
-
- D is the model depth (e.g. ``'resnet_56'``)
- """
-
- if not ResNetCIFAR.is_valid_model_name(model_name):
- raise ValueError('Invalid model name: {}'.format(model_name))
-
- depth = int(model_name.split('_')[-1]) # for resnet56, depth 56, width 16
- if len(model_name.split('_')) == 2:
- width = 16
- else:
- width = int(model_name.split('_')[3])
-
- if (depth - 2) % 3 != 0:
- raise ValueError('Invalid ResNetCIFAR depth: {}'.format(depth))
- num_blocks = (depth - 2) // 6
-
- model_arch = {
- 56: [(width, num_blocks), (2 * width, num_blocks), (4 * width, num_blocks)],
- 20: [(width, num_blocks), (2 * width, num_blocks), (4 * width, num_blocks)],
- }
-
- return ResNetCIFAR(model_arch[depth], initializers, outputs)
-
-
-# adapted from https://raw.githubusercontent.com/matthias-wright/cifar10-resnet/master/model.py
-# under the MIT license
-class ResNet9(nn.Module):
- """A 9-layer residual network, excluding BatchNorms and activation functions.
-
- Based on the myrtle.ai `blog`_ and Deep Residual Learning for Image Recognition (`He et al, 2015`_).
-
- Args:
- num_classes (int, optional): The number of classes. Needed for classification tasks. Default: ``10``.
-
- .. _blog: https://myrtle.ai/learn/how-to-train-your-resnet-4-architecture/
- .. _He et al, 2015: https://arxiv.org/abs/1512.03385
- """
-
- def __init__(self, num_classes: int = 10):
- super().__init__()
-
- self.body = nn.Sequential(
- nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False),
- nn.BatchNorm2d(num_features=64, momentum=0.9),
- nn.ReLU(inplace=True),
- nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False),
- nn.BatchNorm2d(num_features=128, momentum=0.9),
- nn.ReLU(inplace=True),
- nn.MaxPool2d(kernel_size=2, stride=2),
- BasicBlock(inplanes=128, planes=128, stride=1),
- nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
- nn.BatchNorm2d(num_features=256, momentum=0.9),
- nn.ReLU(inplace=True),
- nn.MaxPool2d(kernel_size=2, stride=2),
- nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
- nn.BatchNorm2d(num_features=256, momentum=0.9),
- nn.ReLU(inplace=True),
- nn.MaxPool2d(kernel_size=2, stride=2),
- BasicBlock(inplanes=256, planes=256, stride=1),
- )
-
- self.fc = nn.Linear(in_features=256, out_features=num_classes, bias=True)
-
- def forward(self, x):
- out = self.body(x)
- out = F.avg_pool2d(out, out.size()[3])
- out = out.view(out.size(0), -1)
- out = self.fc(out)
- return out
diff --git a/composer/models/timm/__init__.py b/composer/models/timm/__init__.py
deleted file mode 100644
index b7960b426a..0000000000
--- a/composer/models/timm/__init__.py
+++ /dev/null
@@ -1,9 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""A wrapper around `timm.create_model() `_
-used to create :class:`.ComposerClassifier`."""
-
-from composer.models.timm.model import composer_timm as composer_timm
-
-__all__ = ['composer_timm']
diff --git a/composer/models/timm/model.py b/composer/models/timm/model.py
deleted file mode 100644
index df0ffbca91..0000000000
--- a/composer/models/timm/model.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""A wrapper around `timm.create_model() `_
-used to create :class:`.ComposerClassifier`."""
-
-import warnings
-from typing import Optional
-
-from composer.models.tasks import ComposerClassifier
-from composer.utils.import_helpers import MissingConditionalImportError
-
-__all__ = ['composer_timm']
-
-
-def composer_timm(model_name: str,
- pretrained: bool = False,
- num_classes: int = 1000,
- drop_rate: float = 0.0,
- drop_path_rate: Optional[float] = None,
- drop_block_rate: Optional[float] = None,
- global_pool: Optional[str] = None,
- bn_momentum: Optional[float] = None,
- bn_eps: Optional[float] = None) -> ComposerClassifier:
- """A wrapper around `timm.create_model() `_ used to create :class:`.ComposerClassifier`.
-
- Args:
- model_name (str): timm model name e.g: ``"resnet50"``. List of models can be found at
- `PyTorch Image Models `_.
- pretrained (bool, optional): Imagenet pretrained. Default: ``False``.
- num_classes (int, optional): The number of classes. Needed for classification tasks. Default: ``1000``.
- drop_rate (float, optional): Dropout rate. Default: ``0.0``.
- drop_path_rate (float, optional): Drop path rate (model default if ``None``). Default: ``None``.
- drop_block_rate (float, optional): Drop block rate (model default if ``None``). Default: ``None``.
- global_pool (str, optional): Global pool type, one of (``"fast"``, ``"avg"``, ``"max"``, ``"avgmax"``, ``"avgmaxc"``). Model default if ``None``. Default: ``None``.
- bn_momentum (float, optional): BatchNorm momentum override (model default if ``None``). Default: ``None``.
- bn_eps (float, optional): BatchNorm epsilon override (model default if ``None``). Default: ``None``.
-
- Returns:
- ComposerModel: instance of :class:`.ComposerClassifier` with the specified TIMM model.
-
- Resnet18 Example:
-
- .. testcode::
-
- from composer.models import composer_timm
-
- model = composer_timm(model_name='resnet18') # creates a timm resnet18
- """
- warnings.warn(DeprecationWarning('composer_timm is deprecated and will be removed in v0.18'))
- try:
- import timm
- except ImportError as e:
- raise MissingConditionalImportError(extra_deps_group='timm', conda_package='timm>=0.5.4',
- conda_channel=None) from e
- model = timm.create_model( # type: ignore (third-party)
- model_name=model_name,
- pretrained=pretrained,
- num_classes=num_classes,
- drop_rate=drop_rate,
- drop_path_rate=drop_path_rate,
- drop_block_rate=drop_block_rate,
- global_pool=global_pool,
- bn_momentum=bn_momentum,
- bn_eps=bn_eps)
-
- composer_model = ComposerClassifier(module=model)
- return composer_model
diff --git a/composer/models/unet/README.md b/composer/models/unet/README.md
deleted file mode 100644
index 530832051b..0000000000
--- a/composer/models/unet/README.md
+++ /dev/null
@@ -1,62 +0,0 @@
-# UNet
-[\[Example\]](#example) · [\[Architecture\]](#architecture) · [\[Default Training Hyperparameters\]](#default-training-hyperparameters) · [\[Attribution\]](#attribution) · [\[API Reference\]](#api-reference)
-
-`Vision` / `Segmentation`
-
-Unet is an architecture used for image segmentation.
-
-## Example
-
-
-
-```python
-from composer.models import UNet
-
-model = UNet()
-```
-
-## Architecture
-
-The figure below ([source](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Segmentation/nnUNet)) shows a 3D version of the UNet architecture. Quoting the [Nvidia Deep Learning Examples](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Segmentation/nnUNet), "U-Net is composed of a contractive and an expanding path, that aims at building a bottleneck in its centremost part through a combination of convolution, instance norm and leaky relu operations. After this bottleneck, the image is reconstructed through a combination of convolutions and upsampling. Skip connections are added with the goal of helping the backward flow of gradients in order to improve training."
-
-![unet3d.png](https://storage.googleapis.com/docs.mosaicml.com/images/models/unet3d.png)
-
-
-There are 3 main differences between our implementation and the original NVDA DALI implementation.
-
-The first two refer to removing the NVDA DALI pipeline and replacing all transforms with torch implementations. We are omitting the Zoom transform and use a kernel size of 3 for the Gaussian Blur transform.
-
-While NVDA DLE examples reports the training accuracy using an average of 5 folds, we are using only 1 fold in the interest of faster iteration time, so all of our results are reported using fold 0 and 200 epochs.
-
-
-## Default Training Hyperparameters
-
-Below are the hyperparameters we used to train UNet on the [BraTS](http://braintumorsegmentation.org) image segmentation dataset.
-
-```yaml
-optimizer:
- radam:
- lr: 0.001
- betas: [0.9, 0.999]
- eps: 0.00000001
- weight_decay: 0.0001
-schedulers:
- - constant: {}
-train_batch_size: 64
-max_duration: 200ep
-```
-
-
-## Attribution
-
-The UNet model has been introduced in "U-Net: Convolutional Networks for Biomedical Image Segmentation" by Olaf Ronneberger, Philipp Fischer, Thomas Brox in [https://arxiv.org/abs/1505.04597](https://arxiv.org/abs/1505.04597).
-
-We are using the NVDA DLE examples version in
-[https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Segmentation/nnUNet](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Segmentation/nnUNet).
-
-## API Reference
-
-```{eval-rst}
-.. autoclass:: composer.models.unet.UNet
- :noindex:
-```
diff --git a/composer/models/unet/__init__.py b/composer/models/unet/__init__.py
deleted file mode 100644
index 6f26bd4625..0000000000
--- a/composer/models/unet/__init__.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""The Unet architecture used in image segmentation. The example we are using is for BRATS medical brain tumor dataset.
-
-See the :doc:`Model Card ` for more details.
-"""
-
-from composer.models.unet.unet import UNet as UNet
-
-__all__ = ['UNet']
-
-_task = 'Image Segmentation'
-_dataset = 'BRATS'
-_name = 'UNet'
-_quality = '69.1'
-_metric = 'Dice'
-_ttt = '21m'
-_hparams = 'unet.yaml'
diff --git a/composer/models/unet/_layers.py b/composer/models/unet/_layers.py
deleted file mode 100644
index 6fae767bf5..0000000000
--- a/composer/models/unet/_layers.py
+++ /dev/null
@@ -1,138 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-## Code adapted from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Segmentation/nnUNet/
-
-import numpy as np
-import torch
-import torch.nn as nn
-
-normalizations = {
- 'instancenorm3d': nn.InstanceNorm3d,
- 'instancenorm2d': nn.InstanceNorm2d,
- 'batchnorm3d': nn.BatchNorm3d,
- 'batchnorm2d': nn.BatchNorm2d,
-}
-
-convolutions = {
- 'Conv2d': nn.Conv2d,
- 'Conv3d': nn.Conv3d,
- 'ConvTranspose2d': nn.ConvTranspose2d,
- 'ConvTranspose3d': nn.ConvTranspose3d,
-}
-
-
-def get_norm(name, out_channels):
- if 'groupnorm' in name:
- return nn.GroupNorm(32, out_channels, affine=True)
- return normalizations[name](out_channels, affine=True)
-
-
-def get_conv(in_channels, out_channels, kernel_size, stride, dim, bias=False):
- conv = convolutions[f'Conv{dim}d']
- padding = get_padding(kernel_size, stride)
- return conv(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
-
-
-def get_transp_conv(in_channels, out_channels, kernel_size, stride, dim):
- conv = convolutions[f'ConvTranspose{dim}d']
- padding = get_padding(kernel_size, stride)
- output_padding = get_output_padding(kernel_size, stride, padding)
- return conv(in_channels, out_channels, kernel_size, stride, padding, output_padding, bias=True)
-
-
-def get_padding(kernel_size, stride):
- #kernel_size_np = np.cast(np.ndarray, np.atleast_1d(kernel_size))
- #stride_np = np.cast(np.ndarray, np.atleast_1d(stride))
- kernel_size_np = np.atleast_1d(kernel_size)
- stride_np = np.atleast_1d(stride)
- padding_np = (kernel_size_np - stride_np + 1) / 2 # type: ignore
- padding = tuple(int(p) for p in padding_np) # type: ignore
- return padding if len(padding) > 1 else padding[0]
-
-
-def get_output_padding(kernel_size, stride, padding):
- kernel_size_np = np.atleast_1d(kernel_size)
- stride_np = np.atleast_1d(stride)
- padding_np = np.atleast_1d(padding)
- out_padding_np = 2 * padding_np + stride_np - kernel_size_np
- out_padding = tuple(int(p) for p in out_padding_np)
- return out_padding if len(out_padding) > 1 else out_padding[0]
-
-
-class ConvLayer(nn.Module):
-
- def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
- super(ConvLayer, self).__init__()
- self.conv = get_conv(in_channels, out_channels, kernel_size, stride, kwargs['dim'])
- self.norm = get_norm(kwargs['norm'], out_channels)
- self.lrelu = nn.LeakyReLU(negative_slope=kwargs['negative_slope'], inplace=True)
-
- def forward(self, data):
- out = self.conv(data)
- out = self.norm(out)
- out = self.lrelu(out)
- return out
-
-
-class ConvBlock(nn.Module):
-
- def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
- super(ConvBlock, self).__init__()
- self.conv1 = ConvLayer(in_channels, out_channels, kernel_size, stride, **kwargs)
- self.conv2 = ConvLayer(out_channels, out_channels, kernel_size, 1, **kwargs)
-
- def forward(self, input_data):
- out = self.conv1(input_data)
- out = self.conv2(out)
- return out
-
-
-class ResidBlock(nn.Module):
-
- def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
- super(ResidBlock, self).__init__()
- self.conv1 = ConvLayer(in_channels, out_channels, kernel_size, stride, **kwargs)
- self.conv2 = get_conv(out_channels, out_channels, kernel_size, 1, kwargs['dim'])
- self.norm = get_norm(kwargs['norm'], out_channels)
- self.lrelu = nn.LeakyReLU(negative_slope=kwargs['negative_slope'], inplace=True)
- self.downsample = None
- if max(stride) > 1 or in_channels != out_channels: # type: ignore
- self.downsample = get_conv(in_channels, out_channels, kernel_size, stride, kwargs['dim'])
- self.norm_res = get_norm(kwargs['norm'], out_channels)
-
- def forward(self, input_data):
- residual = input_data
- out = self.conv1(input_data)
- out = self.conv2(out)
- out = self.norm(out)
- if self.downsample is not None:
- residual = self.downsample(residual)
- residual = self.norm_res(residual)
- out = self.lrelu(out + residual)
- return out
-
-
-class UpsampleBlock(nn.Module):
-
- def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
- super(UpsampleBlock, self).__init__()
- self.transp_conv = get_transp_conv(in_channels, out_channels, stride, stride, kwargs['dim'])
- self.conv_block = ConvBlock(2 * out_channels, out_channels, kernel_size, 1, **kwargs)
-
- def forward(self, input_data, skip_data):
- out = self.transp_conv(input_data)
- out = torch.cat((out, skip_data), dim=1)
- out = self.conv_block(out)
- return out
-
-
-class OutputBlock(nn.Module):
-
- def __init__(self, in_channels, out_channels, dim):
- super(OutputBlock, self).__init__()
- self.conv = get_conv(in_channels, out_channels, kernel_size=1, stride=1, dim=dim, bias=True)
- nn.init.constant_(self.conv.bias, 0)
-
- def forward(self, input_data):
- return self.conv(input_data)
diff --git a/composer/models/unet/model.py b/composer/models/unet/model.py
deleted file mode 100644
index 08c49ff57c..0000000000
--- a/composer/models/unet/model.py
+++ /dev/null
@@ -1,123 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""The Unet architecture used in image segmentation. The example we are using is for BRATS medical brain tumor dataset.
-
-See the :doc:`Model Card ` for more details.
-"""
-
-import warnings
-
-import torch.nn as nn
-
-from composer.models.unet._layers import ConvBlock, OutputBlock, ResidBlock, UpsampleBlock
-
-__all__ = ['UNet']
-
-
-class UNet(nn.Module):
- """Unet Architecture adapted from NVidia `Deep Learning Examples`_.
-
- .. _Deep Learning Examples: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Segmentation/nnUNet/
-
- Args:
- in_channels (int): Number of input channels.
- n_class (int): Number of output layers.
- kernels (list): Conv layer kernel sizes.
- strides (list): Conv layer strides.
- normalization_layer (str): Normalization layer type, one of (``"batch"``, ``"instance"``).
- negative_slope (float): Leaky relu negative slope.
- residual (bool): Use residual connections.
- dimension (int): Filter dimensions.
- """
-
- def __init__(
- self,
- in_channels,
- n_class,
- kernels,
- strides,
- normalization_layer,
- negative_slope,
- residual,
- dimension,
- ):
- warnings.warn(DeprecationWarning('UNet is deprecated and will be removed in v0.18'))
- super(UNet, self).__init__()
- self.dim = dimension
- self.n_class = n_class
- self.residual = residual
- self.negative_slope = negative_slope
- self.norm = normalization_layer + f'norm{dimension}d'
- self.filters = [min(2**(5 + i), 320 if dimension == 3 else 512) for i in range(len(strides))]
-
- down_block = ResidBlock if self.residual else ConvBlock
- self.input_block = self.get_conv_block(
- conv_block=down_block,
- in_channels=in_channels,
- out_channels=self.filters[0],
- kernel_size=kernels[0],
- stride=strides[0],
- )
- self.downsamples = self.get_module_list(
- conv_block=down_block,
- in_channels=self.filters[:-1],
- out_channels=self.filters[1:],
- kernels=kernels[1:-1],
- strides=strides[1:-1],
- )
- self.bottleneck = self.get_conv_block(
- conv_block=down_block,
- in_channels=self.filters[-2],
- out_channels=self.filters[-1],
- kernel_size=kernels[-1],
- stride=strides[-1],
- )
- self.upsamples = self.get_module_list(
- conv_block=UpsampleBlock,
- in_channels=self.filters[1:][::-1],
- out_channels=self.filters[:-1][::-1],
- kernels=kernels[1:][::-1],
- strides=strides[1:][::-1],
- )
- self.output_block = self.get_output_block(decoder_level=0)
- self.apply(self.initialize_weights)
- self.n_layers = len(self.upsamples) - 1
-
- def forward(self, input_data):
- out = self.input_block(input_data)
- encoder_outputs = [out]
- for downsample in self.downsamples:
- out = downsample(out)
- encoder_outputs.append(out)
- out = self.bottleneck(out)
- for idx, upsample in enumerate(self.upsamples):
- out = upsample(out, encoder_outputs[self.n_layers - idx])
- out = self.output_block(out)
- return out
-
- def get_conv_block(self, conv_block, in_channels, out_channels, kernel_size, stride):
- return conv_block(
- dim=self.dim,
- stride=stride,
- norm=self.norm,
- kernel_size=kernel_size,
- in_channels=in_channels,
- out_channels=out_channels,
- negative_slope=self.negative_slope,
- )
-
- def get_output_block(self, decoder_level):
- return OutputBlock(in_channels=self.filters[decoder_level], out_channels=self.n_class, dim=self.dim)
-
- def get_module_list(self, in_channels, out_channels, kernels, strides, conv_block):
- layers = []
- for in_channel, out_channel, kernel, stride in zip(in_channels, out_channels, kernels, strides):
- conv_layer = self.get_conv_block(conv_block, in_channel, out_channel, kernel, stride)
- layers.append(conv_layer)
- return nn.ModuleList(layers)
-
- def initialize_weights(self, module):
- name = module.__class__.__name__.lower()
- if name in ['conv2d']:
- nn.init.kaiming_normal_(module.weight, a=self.negative_slope)
diff --git a/composer/models/unet/unet.py b/composer/models/unet/unet.py
deleted file mode 100644
index dde555bb4f..0000000000
--- a/composer/models/unet/unet.py
+++ /dev/null
@@ -1,110 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""A U-Net model extending :class:`.ComposerModel`."""
-
-import logging
-import warnings
-from typing import Any, Dict, Optional, Sequence, Union
-
-import torch
-import torch.nn as nn
-from torchmetrics import Metric
-
-from composer.metrics.metrics import Dice
-from composer.models.base import ComposerModel
-from composer.models.unet.model import UNet as UNetModel
-from composer.utils.import_helpers import MissingConditionalImportError
-
-log = logging.getLogger(__name__)
-
-__all__ = ['UNet']
-
-
-class UNet(ComposerModel):
- """A U-Net model extending :class:`.ComposerModel`.
-
- See U-Net: Convolutional Networks for Biomedical Image Segmentation (`Ronneberger et al, 2015`_)
- on the U-Net architecture.
-
- Args:
- num_classes (int, optional): The number of classes. Needed for classification tasks. Default: ``3``.
-
- .. _Ronneberger et al, 2015: https://arxiv.org/abs/1505.04597
- """
-
- def __init__(self, num_classes: int = 3) -> None:
- warnings.warn(DeprecationWarning('UNet is deprecated and will be removed in v0.18'))
-
- super().__init__()
- try:
- from monai.losses import DiceLoss
- except ImportError as e:
- raise MissingConditionalImportError(extra_deps_group='unet',
- conda_package='monai',
- conda_channel='conda-forge') from e
-
- self.module = self.build_nnunet()
-
- self.dice = Dice(num_classes=num_classes)
- self.dloss = DiceLoss(include_background=False, softmax=True, to_onehot_y=True, batch=True)
- self.closs = nn.CrossEntropyLoss()
-
- def loss(self, outputs: Any, batch: Any, *args, **kwargs) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
- _, y = batch
- y = y.squeeze(1) # type: ignore
- loss = self.dloss(outputs, y)
- loss += self.closs(outputs, y[:, 0].long())
- return loss
-
- @staticmethod
- def metric_mean(name, outputs):
- return torch.stack([out[name] for out in outputs]).mean(dim=0)
-
- def get_metrics(self, is_train: bool = False) -> Dict[str, Metric]:
- return {'Dice': self.dice}
-
- def forward(self, batch: Any) -> torch.Tensor:
- x, _ = batch
- x = x.squeeze(1) # type: ignore
- logits = self.module(x)
- return logits
-
- def inference2d(self, image):
- """Runs inference on a 3D image, by passing each depth slice through the model."""
- batch_modulo = image.shape[2] % 64
- if batch_modulo != 0:
- batch_pad = 64 - batch_modulo
- image = nn.ConstantPad3d((0, 0, 0, 0, batch_pad, 0), 0)(image)
-
- image = torch.transpose(image.squeeze(0), 0, 1)
- preds_shape = (image.shape[0], 4, *image.shape[2:])
- preds = torch.zeros(preds_shape, dtype=image.dtype, device=image.device)
- for start in range(0, image.shape[0] - 64 + 1, 64):
- end = start + 64
- with torch.no_grad():
- pred = self.module(image[start:end])
- preds[start:end] = pred.data
- if batch_modulo != 0:
- preds = preds[batch_pad:] # type: ignore
- return torch.transpose(preds, 0, 1).unsqueeze(0)
-
- def eval_forward(self, batch: Any, outputs: Optional[Any] = None):
- assert self.training is False, 'For validation, model must be in eval mode'
- image, _ = batch
- pred = self.inference2d(image)
- return pred
-
- def build_nnunet(self) -> torch.nn.Module:
- kernels = [[3, 3]] * 6
- strides = [[1, 1]] + [[2, 2]] * 5
- model = UNetModel(in_channels=4,
- n_class=4,
- kernels=kernels,
- strides=strides,
- dimension=2,
- residual=True,
- normalization_layer='batch',
- negative_slope=0.01)
-
- return model
diff --git a/composer/models/vit_small_patch16/__init__.py b/composer/models/vit_small_patch16/__init__.py
deleted file mode 100644
index 9992807ade..0000000000
--- a/composer/models/vit_small_patch16/__init__.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""ViT Small Patch 16 for image classification."""
-
-from composer.models.vit_small_patch16.model import vit_small_patch16 as vit_small_patch16
-
-__all__ = ['vit_small_patch16']
-
-_task = 'Image Classification'
-_dataset = 'ImageNet'
-_name = 'ViT-Small-Patch16'
-_quality = '74.52'
-_metric = 'Top-1 Accuracy'
-_ttt = '1d 59m'
-_hparams = 'vit_small_patch16.yaml'
diff --git a/composer/models/vit_small_patch16/model.py b/composer/models/vit_small_patch16/model.py
deleted file mode 100644
index dacb9db56a..0000000000
--- a/composer/models/vit_small_patch16/model.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""Implements ViT-S/16 as a :class:`.ComposerClassifier`."""
-
-import warnings
-
-from composer.models.tasks import ComposerClassifier
-
-__all__ = ['vit_small_patch16']
-
-
-def vit_small_patch16(num_classes: int = 1000,
- image_size: int = 224,
- channels: int = 3,
- dropout: float = 0.0,
- embedding_dropout: float = 0.0):
- """Helper function to create a :class:`.ComposerClassifier` using a ViT-S/16 model.
-
- See `Training data-efficient image transformers & distillation through attention `_
- (Touvron et al, 2021) for details on ViT-S/16.
-
- Args:
- num_classes (int, optional): number of classes for the model. Default: ``1000``.
- image_size (int, optional): input image size. If you have rectangular images, make sure your image
- size is the maximum of the width and height. Default: ``224``.
- channels (int, optional): number of image channels. Default: ``3``.
- dropout (float, optional): 0.0 - 1.0 dropout rate. Default: ``0``.
- embedding_dropout (float, optional): 0.0 - 1.0 embedding dropout rate. Default: ``0``.
-
- Returns:
- ComposerModel: instance of :class:`.ComposerClassifier` with a ViT-S/16 model.
- """
- warnings.warn(DeprecationWarning('vit_small_patch16 is deprecated and will be removed in v0.18'))
-
- from vit_pytorch import ViT
- model = ViT(
- image_size=image_size,
- channels=channels,
- num_classes=num_classes,
- dim=384, # embed dim/width
- patch_size=16,
- depth=12, # layers
- heads=6,
- mlp_dim=1536,
- dropout=dropout,
- emb_dropout=embedding_dropout)
-
- composer_model = ComposerClassifier(module=model)
- return composer_model
diff --git a/composer/optim/decoupled_weight_decay.py b/composer/optim/decoupled_weight_decay.py
index 35a235cce7..2d20aad286 100644
--- a/composer/optim/decoupled_weight_decay.py
+++ b/composer/optim/decoupled_weight_decay.py
@@ -11,7 +11,7 @@
import logging
import math
-from typing import Iterable, List, Tuple, Union
+from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch.optim import SGD, AdamW
@@ -70,8 +70,9 @@ def __init__(
group['initial_lr'] = group['lr']
@staticmethod
- def sgdw(params: List[torch.Tensor], d_p_list: List[torch.Tensor], momentum_buffer_list: List[torch.Tensor], *,
- weight_decay: float, momentum: float, lr: float, initial_lr: float, dampening: float, nesterov: bool):
+ def sgdw(params: List[torch.Tensor], d_p_list: List[torch.Tensor],
+ momentum_buffer_list: List[Optional[torch.Tensor]], *, weight_decay: float, momentum: float, lr: float,
+ initial_lr: float, dampening: float, nesterov: bool):
r"""Functional API that performs SGDW algorithm computation.
Args:
@@ -109,7 +110,7 @@ def sgdw(params: List[torch.Tensor], d_p_list: List[torch.Tensor], momentum_buff
param.add_(d_p, alpha=-lr)
- @torch.no_grad()
+ @torch.no_grad() # pyright: ignore[reportUntypedFunctionDecorator]
def step(self, closure=None):
"""Performs a single optimization step.
@@ -263,7 +264,7 @@ def adamw(params: List[torch.Tensor], grads: List[torch.Tensor], exp_avgs: List[
param.addcdiv_(exp_avg, denom, value=-step_size)
- @torch.no_grad()
+ @torch.no_grad() # pyright: ignore[reportUntypedFunctionDecorator]
def step(self, closure=None):
"""Performs a single optimization step.
diff --git a/composer/optim/scheduler.py b/composer/optim/scheduler.py
index 294d26ddb4..d8c6c82c6d 100644
--- a/composer/optim/scheduler.py
+++ b/composer/optim/scheduler.py
@@ -18,9 +18,9 @@
import warnings
from typing import TYPE_CHECKING, List, Union
-from torch.optim.lr_scheduler import LambdaLR
+from torch.optim.lr_scheduler import LambdaLR, LRScheduler
-from composer.core import PyTorchScheduler, State, Time, TimeUnit
+from composer.core import State, Time, TimeUnit
if TYPE_CHECKING:
from typing import Protocol
@@ -31,10 +31,21 @@
log = logging.getLogger(__name__)
__all__ = [
- 'ComposerScheduler', 'compile_composer_scheduler', 'StepScheduler', 'MultiStepScheduler', 'ConstantScheduler',
- 'LinearScheduler', 'ExponentialScheduler', 'CosineAnnealingScheduler', 'CosineAnnealingWarmRestartsScheduler',
- 'PolynomialScheduler', 'MultiStepWithWarmupScheduler', 'ConstantWithWarmupScheduler', 'LinearWithWarmupScheduler',
- 'CosineAnnealingWithWarmupScheduler', 'PolynomialWithWarmupScheduler'
+ 'ComposerScheduler',
+ 'compile_composer_scheduler',
+ 'StepScheduler',
+ 'MultiStepScheduler',
+ 'ConstantScheduler',
+ 'LinearScheduler',
+ 'ExponentialScheduler',
+ 'CosineAnnealingScheduler',
+ 'CosineAnnealingWarmRestartsScheduler',
+ 'PolynomialScheduler',
+ 'MultiStepWithWarmupScheduler',
+ 'ConstantWithWarmupScheduler',
+ 'LinearWithWarmupScheduler',
+ 'CosineAnnealingWithWarmupScheduler',
+ 'PolynomialWithWarmupScheduler',
]
@@ -147,7 +158,7 @@ def _convert_time(time: Union[str, Time[int], Time[float]], state: State, ssr: f
return Time(value=int(time.value * ssr), unit=time.unit)
-def compile_composer_scheduler(scheduler: ComposerScheduler, state: State, ssr: float = 1.0) -> PyTorchScheduler:
+def compile_composer_scheduler(scheduler: ComposerScheduler, state: State, ssr: float = 1.0) -> LRScheduler:
"""Converts a stateless scheduler into a PyTorch scheduler object.
While the resulting scheduler provides a ``.step()`` interface similar to other PyTorch schedulers, the scheduler is
@@ -160,7 +171,7 @@ def compile_composer_scheduler(scheduler: ComposerScheduler, state: State, ssr:
state (State): The Composer Trainer's state.
Returns:
- compiled_scheduler (PyTorchScheduler): The scheduler, in a form compatible with PyTorch scheduler interfaces.
+ compiled_scheduler (LRScheduler): The scheduler, in a form compatible with PyTorch scheduler interfaces.
"""
optimizers = state.optimizers
if len(optimizers) != 1:
diff --git a/composer/profiler/profiler.py b/composer/profiler/profiler.py
index c88c1f0912..a3a7127e58 100644
--- a/composer/profiler/profiler.py
+++ b/composer/profiler/profiler.py
@@ -9,6 +9,8 @@
import pathlib
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Sequence, Tuple, Union
+from composer.core import Callback
+from composer.loggers import Logger
from composer.profiler.json_trace_handler import JSONTraceHandler
from composer.profiler.marker import Marker
from composer.profiler.profiler_action import ProfilerAction
@@ -18,14 +20,14 @@
from composer.utils import ensure_tuple, parse_uri
if TYPE_CHECKING:
- from composer.core import Callback, State
+ from composer.core import State
__all__ = ['Profiler']
log = logging.getLogger(__name__)
-class Profiler:
+class Profiler(Callback):
"""Composer Profiler.
See the :doc:`Profiling Guide ` for additional information.
@@ -118,6 +120,8 @@ def __init__(
self.schedule = schedule
self.state = None
self._callbacks: List[Callback] = []
+ # Used to count skip_first starting from resumption timestamp
+ self.resumption_batch_idx: int = 0
self.remote_filenames: List[str] = []
# First, add each remote file name to self.remote_filenames to create RemoteUploaderDownloader logger in trainer. [s3://bucket/path/to/file]
# Then modify remote file name to be a local path to pass into torch_profiler and system_profiler. e.g: path/to/file
@@ -185,6 +189,7 @@ def bind_to_state(
state (State): The training state.
"""
self.state = state
+ self.state.callbacks.append(self)
self.state.callbacks.extend(self._callbacks)
self.state.callbacks.extend(self._trace_handlers)
@@ -289,3 +294,7 @@ def should_record(state: State) -> bool:
)
self._names_to_markers[name].categories = categories
return self._names_to_markers[name]
+
+ def after_load(self, state: State, logger: Logger) -> None:
+ del logger
+ self.resumption_batch_idx = int(state.timestamp.batch_in_epoch)
diff --git a/composer/profiler/profiler_schedule.py b/composer/profiler/profiler_schedule.py
index 02b72b8a50..08d2549c2b 100644
--- a/composer/profiler/profiler_schedule.py
+++ b/composer/profiler/profiler_schedule.py
@@ -23,10 +23,11 @@ def cyclic_schedule(
This function returns a schedule function that uses a cyclic profiling window. The resulting function can be
passed as the ``prof_schedule`` argument to the :class:`.Trainer`.
- The cyclic window skips the first ``skip_first`` batches in every epoch. Then, it performs a cycle of
- skipping ``wait`` batches, warming up for ``warmup`` batches, and recording ``active`` batches.
- It repeats this cycle up to ``repeat`` times per epoch (or for the entire epoch, if ``repeat`` is 0).
- This logic repeats every epoch.
+ The cyclic window skips the first ``skip_first`` + ``resumption_batch_idx`` batches in every epoch.
+ ``resumption_batch_idx`` is accessed from state.profiler. It is the ``state.timestamp.batch_in_epoch``
+ when resuming training. Then, it performs a cycle of skipping ``wait`` batches, warming up for ``warmup``
+ batches, and recording ``active`` batches. It repeats this cycle up to ``repeat`` times per epoch (or
+ for the entire epoch, if ``repeat`` is 0). This logic repeats every epoch.
Args:
skip_first (int, optional): Number of batches to skip profiling at epoch start. Defaults to ``0``.
@@ -46,12 +47,16 @@ def schedule(state: State):
# do wait, then warump, then active, up to repeat times per cycle
cycle_len = wait + warmup + active
batch_idx = int(state.timestamp.batch_in_epoch)
- if batch_idx < skip_first:
+ if state.profiler is not None:
+ skip_first_after_resumption = skip_first + state.profiler.resumption_batch_idx
+ else:
+ skip_first_after_resumption = skip_first
+ if batch_idx < skip_first_after_resumption:
return ProfilerAction.SKIP
- if repeat != 0 and batch_idx >= cycle_len * repeat + skip_first:
+ if repeat != 0 and batch_idx >= cycle_len * repeat + skip_first_after_resumption:
# exhausted the repeat
return ProfilerAction.SKIP
- position_in_cycle = (batch_idx - skip_first) % cycle_len
+ position_in_cycle = (batch_idx - skip_first_after_resumption) % cycle_len
if position_in_cycle < wait:
return ProfilerAction.SKIP
if position_in_cycle < wait + warmup:
diff --git a/composer/profiler/system_profiler.py b/composer/profiler/system_profiler.py
index 3f9c928c23..3bc19fb3ca 100644
--- a/composer/profiler/system_profiler.py
+++ b/composer/profiler/system_profiler.py
@@ -98,7 +98,9 @@ def _stats_thread(self, profiler: Profiler):
})
if self.profile_disk:
- disk_io_counters = cast(Dict[str, psutil._common.sdiskio], psutil.disk_io_counters(perdisk=True))
+ disk_io_counters = cast(
+ Dict[str, psutil._common.sdiskio], # type: ignore
+ psutil.disk_io_counters(perdisk=True))
for disk_name, disk_stats in disk_io_counters.items():
for field_name in ('read_count', 'write_count', 'read_bytes', 'write_bytes', 'read_time',
'write_time', 'busy_time'):
@@ -106,7 +108,9 @@ def _stats_thread(self, profiler: Profiler):
categories=['disk']).counter({'field_name': getattr(disk_stats, field_name)})
if self.profile_net:
- net_io_counters = cast(Dict[str, psutil._common.snetio], psutil.net_io_counters(pernic=True))
+ net_io_counters = cast(
+ Dict[str, psutil._common.snetio], # type: ignore
+ psutil.net_io_counters(pernic=True))
for nic, nic_stats in net_io_counters.items():
profiler.marker(f'network/{nic}/kb_sent',
categories=['net']).counter({'kb_sent': nic_stats.bytes_sent / 2**3})
diff --git a/composer/profiler/torch_profiler.py b/composer/profiler/torch_profiler.py
index a8c51bb27b..0f8f1f4fb0 100644
--- a/composer/profiler/torch_profiler.py
+++ b/composer/profiler/torch_profiler.py
@@ -259,24 +259,23 @@ def handler_fn(prof: torch.profiler.profiler.profile):
timestamp = state.timestamp
log.info(f'PyTorch Chrome trace profiler enabled: {self.filename if self.filename else False}')
- if self.filename is not None:
- trace_file_name = os.path.join(
- folder_name,
- format_name_with_dist_and_time(self.filename, run_name=state.run_name, timestamp=timestamp),
- )
- trace_file_dirname = os.path.dirname(trace_file_name)
- if trace_file_dirname:
- os.makedirs(trace_file_dirname, exist_ok=True)
- prof.export_chrome_trace(trace_file_name)
- state.profiler.record_chrome_json_trace_file(trace_file_name)
- if self.remote_file_name is not None:
- trace_remote_file_name = format_name_with_dist_and_time(self.remote_file_name,
- run_name=state.run_name,
- timestamp=timestamp)
- trace_remote_file_name = trace_remote_file_name.lstrip('/')
- logger.upload_file(remote_file_name=trace_remote_file_name,
- file_path=trace_file_name,
- overwrite=self.overwrite)
+ trace_file_name = os.path.join(
+ folder_name,
+ format_name_with_dist_and_time(self.filename, run_name=state.run_name, timestamp=timestamp),
+ )
+ trace_file_dirname = os.path.dirname(trace_file_name)
+ if trace_file_dirname:
+ os.makedirs(trace_file_dirname, exist_ok=True)
+ prof.export_chrome_trace(trace_file_name)
+ state.profiler.record_chrome_json_trace_file(trace_file_name)
+ if self.remote_file_name is not None:
+ trace_remote_file_name = format_name_with_dist_and_time(self.remote_file_name,
+ run_name=state.run_name,
+ timestamp=timestamp)
+ trace_remote_file_name = trace_remote_file_name.lstrip('/')
+ logger.upload_file(remote_file_name=trace_remote_file_name,
+ file_path=trace_file_name,
+ overwrite=self.overwrite)
log.info(
f'PyTorch memory timeline profiler enabled: {self.memory_filename if self.memory_filename else False}')
diff --git a/composer/trainer/_deepspeed.py b/composer/trainer/_deepspeed.py
index a3ef6e0ef2..0217770a23 100644
--- a/composer/trainer/_deepspeed.py
+++ b/composer/trainer/_deepspeed.py
@@ -161,7 +161,7 @@ def _fix_batch_precision_for_deepspeed(batch: Batch, precision: Precision) -> Ba
Batch: The batch with it's precision adjusted to the specified precision.
"""
if precision == Precision.AMP_FP16:
- return map_collection(batch, _convert_fp32_tensor_to_fp16) # type: ignore
+ return map_collection(batch, _convert_fp32_tensor_to_fp16)
elif precision == Precision.AMP_BF16:
- return map_collection(batch, _convert_fp32_tensor_to_bf16) # type: ignore
+ return map_collection(batch, _convert_fp32_tensor_to_bf16)
return batch
diff --git a/composer/trainer/_scale_schedule.py b/composer/trainer/_scale_schedule.py
index 5cdb37da60..cc94caf7c4 100644
--- a/composer/trainer/_scale_schedule.py
+++ b/composer/trainer/_scale_schedule.py
@@ -3,12 +3,11 @@
from collections import Counter
-from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts, ExponentialLR, MultiStepLR, StepLR
+from torch.optim.lr_scheduler import (CosineAnnealingLR, CosineAnnealingWarmRestarts, ExponentialLR, LRScheduler,
+ MultiStepLR, StepLR)
-from composer.core import PyTorchScheduler
-
-def scale_pytorch_scheduler(scheduler: PyTorchScheduler, ssr: float):
+def scale_pytorch_scheduler(scheduler: LRScheduler, ssr: float):
"""Makes a learning rate schedule take a different number of epochs.
Training for less time is a strong baseline approach to speeding up
diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py
index f2c8c615b4..8b76f8b1ba 100644
--- a/composer/trainer/dist_strategy.py
+++ b/composer/trainer/dist_strategy.py
@@ -11,6 +11,10 @@
import torch
from packaging import version
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (CheckpointImpl, apply_activation_checkpointing,
+ checkpoint_wrapper)
+from torch.distributed.fsdp import FullyShardedDataParallel
+from torch.distributed.fsdp._common_utils import clean_tensor_name
from torch.nn.parallel import DistributedDataParallel
from torchmetrics import Metric, MetricCollection
@@ -20,7 +24,7 @@
from composer.trainer.mosaic_fsdp import patch_pytorch
from composer.trainer.mosaic_fsdp_utils import (BACKWARD_PREFETCH_MAP, SHARDING_MAP, _set_custom_fsdp_module_kwargs,
get_cpu_offload, get_mixed_precision)
-from composer.utils import StringEnum, dist, ensure_tuple, using_torch_2
+from composer.utils import StringEnum, dist, ensure_tuple
__all__ = ['DDPSyncStrategy', 'ddp_sync_context', 'prepare_ddp_module', 'prepare_fsdp_module']
@@ -178,13 +182,7 @@ def _recreate_fsdp_param_groups_from_unwrapped_opt_info(
Returns a list of param groups, referencing the fsdp parameters
"""
- is_torch_2_0 = using_torch_2()
- if not is_torch_2_0:
- raise RuntimeError('Helper function is only supported in torch 2.0')
-
- from torch.distributed.fsdp._common_utils import clean_tensor_name
-
- # initialize an empty list of parameters for each optimizer group
+ # Initialize an empty list of parameters for each optimizer group
for group_num in group_num_to_optimizer_info.keys():
group_num_to_optimizer_info[group_num]['params'] = []
@@ -217,16 +215,6 @@ def prepare_fsdp_module(
device (Device): The device being used by the Trainer.
auto_microbatching (bool, optional): Whether or not auto microbatching is enabled.
"""
- if version.parse(torch.__version__) < version.parse('1.13.0'):
- raise RuntimeError('To use FSDP with Composer, you must use torch>=1.13.0.')
- is_torch_2_0 = using_torch_2()
- from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (CheckpointImpl,
- apply_activation_checkpointing,
- checkpoint_wrapper)
- from torch.distributed.fsdp import FullyShardedDataParallel
- if not is_torch_2_0:
- from torch.distributed.fsdp.flatten_params_wrapper import FlattenParamsWrapper
-
patch_pytorch()
set_fsdp_default(fsdp_config)
@@ -243,10 +231,6 @@ def prepare_fsdp_module(
'gpu and some ranks are on meta. Either keep all ranks on the same '
"device or set fsdp_config['sync_module_states'] = True. Otherwise, "
'some weights may be randomly initialized when loading a checkpoint.')
- if fsdp_config['sharding_strategy'] in ('HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'):
- raise ValueError('HSDP (HYBRID_SHARD or _HYBRID_SHARD_ZERO2) requires '
- 'fsdp_config["sync_module_states"] = True or different replicas will '
- 'have different weights.')
# Check if other ranks OOMed after forward/backward pass when using auto microbatching. This
# may happen when close to memory limit or with uneven memory usage across ranks. Since we
@@ -265,14 +249,13 @@ def sync_hook(*args):
raise RuntimeError('CUDA out of memory encountered on a different rank')
kwargs = {}
- if is_torch_2_0:
- # Support of new parameter `use_orig_params` in PyTorch 2.0 or higher.
- # Setting this to `True` has FSDP use `module`'s original parameters via method
- # `nn.Module.named_parameters` instead of FSDP's internal class `FlatParameter`. However,
- # setting it to `False` exposes FSDP's internal class `FlatParameter` via method
- # `nn.Module.named_parameters`.
- # Setting it to `True` is mandatory when using `torch.compile()`.
- kwargs['use_orig_params'] = fsdp_config['use_orig_params']
+ if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.2.0'):
+ if 'device_mesh' in fsdp_config:
+ from torch.distributed._tensor import init_device_mesh
+ kwargs['device_mesh'] = init_device_mesh(
+ 'cuda',
+ tuple([int(x) for x in fsdp_config['device_mesh']]),
+ )
# necessary variables for optimizers with multiple param groups in FSDP
num_param_groups = None
@@ -291,9 +274,9 @@ def sync_hook(*args):
num_param_groups = len(optim.param_groups)
if num_param_groups > 1:
- if not (is_torch_2_0 and kwargs['use_orig_params']):
- raise RuntimeError('Multiple optimizer groups with FSDP are only supported on torch 2.0 \
- with use_orig_params=True.')
+ if not fsdp_config['use_orig_params']:
+ raise RuntimeError('Multiple optimizer groups with FSDP are only supported with '
+ 'use_orig_params=True.')
# optimizer.param_groups do not contain parameter names which are needed
# to keep track of the different parameters in each group
# so we use the pointers between model.parameters() and model.named_parameters()
@@ -367,6 +350,7 @@ def sync_hook(*args):
state_dict_type = fsdp_config['state_dict_type']
activation_checkpointing_reentrant = fsdp_config['activation_checkpointing_reentrant']
sharded_ckpt_prefix_dir = fsdp_config['sharded_ckpt_prefix_dir']
+ use_orig_params = fsdp_config['use_orig_params']
# We choose to not wrap the ComposerModel directly, but instead wrap any submodules like `ComposerModel.model`
# This makes it safer to call ComposerModel-specific functions like 'eval_forward' that
@@ -546,19 +530,10 @@ def __auto_wrap_policy(module: torch.nn.Module, recurse: bool, nonwrapped_numel:
module.register_full_backward_hook(sync_hook)
return should_be_wrapped
- if is_torch_2_0:
-
- def _auto_wrap_policy_new(module: torch.nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
- return __auto_wrap_policy(module, recurse, nonwrapped_numel)
-
- _auto_wrap_policy = _auto_wrap_policy_new
+ def _auto_wrap_policy_new(module: torch.nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
+ return __auto_wrap_policy(module, recurse, nonwrapped_numel)
- else:
-
- def _auto_wrap_policy_old(module: torch.nn.Module, recurse: bool, unwrapped_params: int) -> bool:
- return __auto_wrap_policy(module, recurse, unwrapped_params)
-
- _auto_wrap_policy = _auto_wrap_policy_old
+ _auto_wrap_policy = _auto_wrap_policy_new
fsdp_obj = FullyShardedDataParallel(
obj,
@@ -574,6 +549,7 @@ def _auto_wrap_policy_old(module: torch.nn.Module, recurse: bool, unwrapped_para
sync_module_states=sync_module_states,
forward_prefetch=forward_prefetch,
limit_all_gathers=limit_all_gathers,
+ use_orig_params=use_orig_params,
**kwargs,
)
@@ -636,8 +612,6 @@ def _auto_wrap_policy_old(module: torch.nn.Module, recurse: bool, unwrapped_para
# If module has attribute `module._activation_checkpointing = ...`, always respect it
# Otherwise checkpoint if root object `obj.activation_checkpointing_fn(module)` is true
def _check_fn(module: torch.nn.Module) -> bool:
- if not is_torch_2_0 and isinstance(module, FlattenParamsWrapper):
- return False
if isinstance(module, FullyShardedDataParallel):
return False
if hasattr(module, '_activation_checkpointing'):
@@ -657,24 +631,22 @@ def _check_fn(module: torch.nn.Module) -> bool:
# Print FSDP wrapped model and FSDP config if `verbose=True`
if fsdp_config['verbose']:
- print(f'FSDP: Wrapped Model:')
- print(model)
- print(f'FSDP: Using sharding_strategy={sharding_strategy}')
- print(f'FSDP: Using cpu_offload={cpu_offload}')
- print(f'FSDP: Using mixed_precision={mixed_precision}')
- print(f'FSDP: Using backward_prefetch={backward_prefetch}')
- print(f'FSDP: Using activation_checkpointing={activation_checkpointing}')
- print(f'FSDP: Using activation_cpu_offload={activation_cpu_offload}')
- print(f'FSDP: Using sync_module_states={sync_module_states}')
- print(f'FSDP: Using forward_prefetch={forward_prefetch}')
- print(f'FSDP: Using limit_all_gathers={limit_all_gathers}')
- print(f'FSDP: Using state_dict_type={state_dict_type}')
- print(f'FSDP: Using sharded_ckpt_prefix_dir={sharded_ckpt_prefix_dir}')
+ log.info(f'FSDP: Wrapped model: {model}')
+ log.info(f'FSDP: Using sharding_strategy={sharding_strategy}')
+ log.info(f'FSDP: Using cpu_offload={cpu_offload}')
+ log.info(f'FSDP: Using mixed_precision={mixed_precision}')
+ log.info(f'FSDP: Using backward_prefetch={backward_prefetch}')
+ log.info(f'FSDP: Using activation_checkpointing={activation_checkpointing}')
+ log.info(f'FSDP: Using activation_cpu_offload={activation_cpu_offload}')
+ log.info(f'FSDP: Using sync_module_states={sync_module_states}')
+ log.info(f'FSDP: Using forward_prefetch={forward_prefetch}')
+ log.info(f'FSDP: Using limit_all_gathers={limit_all_gathers}')
+ log.info(f'FSDP: Using state_dict_type={state_dict_type}')
+ log.info(f'FSDP: Using sharded_ckpt_prefix_dir={sharded_ckpt_prefix_dir}')
# Rebuild optimizer now that parameters are sharded
if optimizers:
- optimizers_tuple = ensure_tuple(optimizers)
- optim = optimizers_tuple[0]
+ optim = ensure_tuple(optimizers)[0]
optim.param_groups.clear()
assert num_param_groups is not None
diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py
index ad0fd0904c..07a4f15fbf 100644
--- a/composer/trainer/mosaic_fsdp.py
+++ b/composer/trainer/mosaic_fsdp.py
@@ -6,30 +6,16 @@
"""Monkey patch FSDPs _auto_wrap to enable module_kwargs and custom process_group cache and ChunkShardingSpec to enable sharding over all gpus."""
+# pyright: reportGeneralTypeIssues=false
import torch
from packaging import version
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed.fsdp import FullyShardedDataParallel
-from composer.trainer.mosaic_fsdp_utils import (_sharded_pre_load_state_dict_hook, build_metadata,
- custom_auto_wrap_t1p13p1)
-
def patch_pytorch():
"""Monkey patches pytorch functions based on pytorch version."""
- if version.parse(torch.__version__) < version.parse('1.13.1'):
- raise NotImplementedError(f'Not supported for torch < 1.13.1')
-
- elif version.parse(torch.__version__) < version.parse('2.0.0'):
- # Monkey patch for torch < 2.0 ie torch == 1.13.1
-
- # Monkey patch _auto_wrap with _custom_auto_wrap fn
- FullyShardedDataParallel._auto_wrap = custom_auto_wrap_t1p13p1 # type: ignore
-
- elif version.parse(torch.__version__) < version.parse('2.0.1'):
- raise NotImplementedError(f'Not supported for torch == 2.0.0')
-
- elif version.parse(torch.__version__) < version.parse('2.0.2'):
+ if version.parse(torch.__version__) < version.parse('2.0.2'):
# Monkey patch for torch == 2.0.1
# Monkey patch __init__ where __init__ calls the custom _auto_wrap fn
@@ -38,16 +24,23 @@ def patch_pytorch():
FullyShardedDataParallel.__init__ = init_fn_t2p0p1 # type: ignore
# Monkey patch sharding method
+ from composer.trainer.mosaic_fsdp_utils import build_metadata
+
ChunkShardingSpec.build_metadata = build_metadata
elif version.parse(torch.__version__) < version.parse('2.1.1'):
# Monkey patch for torch < 2.1.1 ie torch == 2.1.0
# Monkey patch sharding method
+ from composer.trainer.mosaic_fsdp_utils import build_metadata
+
ChunkShardingSpec.build_metadata = build_metadata
# Monkey patch partial state dict handling
from torch.distributed.fsdp import _state_dict_utils
+
+ from composer.trainer.mosaic_fsdp_utils import _sharded_pre_load_state_dict_hook
+
_state_dict_utils._sharded_pre_load_state_dict_hook = (_sharded_pre_load_state_dict_hook)
# Allow 2D HSDP
@@ -61,23 +54,34 @@ def patch_pytorch():
from torch.distributed.fsdp import _runtime_utils
_runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None
- # Better overlap communication and computation
- from composer.trainer.mosaic_fsdp_utils import (_root_pre_forward, _share_state_and_init_handle_attrs_t2p1,
- _wait_for_computation_stream, forward)
- _runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p1
- _runtime_utils._wait_for_computation_stream = _wait_for_computation_stream
- _runtime_utils._root_pre_forward = _root_pre_forward
- FullyShardedDataParallel.forward = forward
-
elif version.parse(torch.__version__) < version.parse('2.2.1'):
# Monkey patch for torch < 2.2.1 ie torch == 2.2.0
- # Better overlap communication and computation
+ # Allow 2D HSDP
from torch.distributed.fsdp import _runtime_utils
+ _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None
+
+ elif version.parse(torch.__version__) < version.parse('2.3.1'):
+ # Monkey patch for torch < 2.3.1 ie torch == 2.3.0
+ # Note: this is the same patch as 2.2.0, we are just making a new if branch
+ # for clarity and modularity of changes.
+
+ # Allow 2D HSDP
+ from torch.distributed.fsdp import _runtime_utils
+ _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None
+
+ # Monkeypatch state_dict
+ from composer.trainer.mosaic_fsdp_utils import init_fn_t2p3p0
+ FullyShardedDataParallel.__init__ = init_fn_t2p3p0
+
+ # Monkeypatch state_dict
+ from torch.distributed.checkpoint import state_dict # type: ignore
+
+ from composer.trainer.mosaic_fsdp_utils import _verify_options_t2p3p0
+ state_dict._verify_options = _verify_options_t2p3p0
+
+ # Monkeypatch sharding optim state
+ from torch.distributed.fsdp import _optim_utils
- from composer.trainer.mosaic_fsdp_utils import (_root_pre_forward, _share_state_and_init_handle_attrs_t2p2,
- _wait_for_computation_stream, forward)
- _runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p2
- _runtime_utils._wait_for_computation_stream = _wait_for_computation_stream
- _runtime_utils._root_pre_forward = _root_pre_forward
- FullyShardedDataParallel.forward = forward
+ from composer.trainer.mosaic_fsdp_utils import _shard_orig_param_state
+ _optim_utils._shard_orig_param_state = _shard_orig_param_state
diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py
index 3cf26d79ec..5b08f4c35f 100644
--- a/composer/trainer/mosaic_fsdp_utils.py
+++ b/composer/trainer/mosaic_fsdp_utils.py
@@ -4,13 +4,18 @@
# Released under BSD 3-Clause License,
# Copyright (c) Facebook, Inc. and its affiliates.
+# yapf: disable
+# isort: skip_file
+
"""Utilities for monkey patching FSDP."""
import functools
import logging
import math
import warnings
-from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Set, Tuple, Union, cast, no_type_check
+import contextlib
+from dataclasses import asdict
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union, cast, no_type_check
import torch
import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
@@ -35,6 +40,7 @@
torch.__version__) < version.parse('2.2.0'):
from torch.distributed.fsdp._common_utils import _FSDPState
+
log = logging.getLogger(__name__)
SHARDING_MAP = {
@@ -206,153 +212,6 @@ def _set_custom_fsdp_module_kwargs(module_kwargs: Dict, process_group_cache: Dic
return module_kwargs
-
-def _custom_recursive_wrap_t1p13p1(
- module: nn.Module,
- auto_wrap_policy: Callable,
- wrapper_cls: Callable,
- ignored_modules: Set[nn.Module],
- ignored_params: Set[nn.Parameter],
- process_group_cache: Dict[Tuple[int], Any],
- only_wrap_children: bool = False,
- **kwargs: Any,
-) -> Tuple[nn.Module, int]:
- """Updates FSDPs _recursive_wrap to enable module_kwargs and custom process_group cache.
-
- torch version must be 1.13.1.
-
- modified version of
- https://github.com/pytorch/pytorch/blob/d922c29a22e4bf0fba49526f7536395eb8cd66f4/torch/distributed/fsdp/wrap.py#L353
- which recursively wraps modules as FSDP modules for parameter sharding.
- This modification enables the user to pass custom FSDP arguements for every wrapped module.
- The added process_group_cache enables different FSDP modules to, when appropriate, use the
- same process group instead of instantiating a new process group.
-
- Automatically wrap child modules of *module* that meet the given
- criteria with :func:`auto_wrap`. Does not rely on _ConfigAutoWrap.
-
- Args:
- module (nn.Module):
- module to recursively wrap
- auto_wrap_policy (Callable):
- A callable specifying a policy to recursively wrap layers with FSDP.
- ignored_modules (Set[torch.nn.Module]): Modules to ignore when
- wrapping.
- ignored_params (Set[torch.nn.Parameter]): Parameters to ignore when
- wrapping; these should be the parameters contained in the modules
- in ``ignored_modules``.
- process_group_cache (Dict[Tuple[int], Any]): a cache of process_group to
- use instead of potentially instantiating a new process_group
-
- Returns:
- (nn.Module, int):
- Wrapped module and the number parameters wrapped recursively.
- """
- from torch.distributed.fsdp.wrap import _wrap
-
- assert auto_wrap_policy is not None, 'Must specify auto_wrap_policy.'
- assert wrapper_cls is not None, 'Must specify wrapper_cls'
- # Make sure no child is already wrapped.
- for _, child in module.named_modules():
- if child in ignored_modules:
- continue
- try:
- assert not isinstance(child, cast(type, wrapper_cls))
- except TypeError:
- # wrapper_cls is a function as opposed to a class type, just bypass above check.
- pass
-
- # We count all params, assuming none of them are already wrapped.
- num_params = sum(p.numel() for p in module.parameters() if p not in ignored_params)
-
- assert auto_wrap_policy is not None
- if auto_wrap_policy(module=module, recurse=True, unwrapped_params=num_params):
- total_wrapped_params = 0
- # Iterate through the children, recursively wrap if necessary
- for name, child in module.named_children():
- if child in ignored_modules:
- continue
- wrapped_child, num_wrapped_params = _custom_recursive_wrap_t1p13p1(
- module=child,
- auto_wrap_policy=auto_wrap_policy,
- wrapper_cls=wrapper_cls,
- ignored_modules=ignored_modules,
- ignored_params=ignored_params,
- process_group_cache=process_group_cache,
- **kwargs,
- )
- setattr(module, name, wrapped_child)
- # Keep track of how many parameters have been wrapped
- total_wrapped_params += num_wrapped_params
- # decide if we need to wrap the current module,
- # since the left over parameters exceed the number of params to wrap
- remainder = num_params - total_wrapped_params
- module_kwargs = auto_wrap_policy(module=module, recurse=False, unwrapped_params=remainder)
- if not only_wrap_children and module_kwargs:
- # CHANGE: We modify the original code to support custom FSDP kwargs and add
- # the process_group_cache to avoid instantiating a new process group.
- module_kwargs = module_kwargs if isinstance(module_kwargs, dict) else {}
- module_kwargs = _set_custom_fsdp_module_kwargs(module_kwargs, process_group_cache)
-
- final_kwargs = {**kwargs, **module_kwargs}
-
- # Leaf node or final wrapping of the remainder both happen here.
- return _wrap(module, wrapper_cls, **final_kwargs), num_params
- else:
- return module, total_wrapped_params
- return module, 0
-
-
-def custom_auto_wrap_t1p13p1(
- self,
- auto_wrap_kwargs: Dict[str, Any],
- fsdp_kwargs: Dict[str, Any],
-) -> None:
- """Updates _auto_wrap to enable module_kwargs.
-
- torch version must be 1.13.1.
-
- modified version of
- https://github.com/pytorch/pytorch/blob/d922c29a22e4bf0fba49526f7536395eb8cd66f4/torch/distributed/fsdp/fully_sharded_data_parallel.py#L1252
- FSDP's _auto_wrap recursively wraps modules as FSDP modules for parameter sharding.
- This modification enables the user to pass custom FSDP arguements for every wrapped module.
- The added process_group_cache enables different FSDP modules to, when appropriate, use the
- same process group instead of instantiating a new process group.
-
- Recursively auto wraps the root module given by the key "module" in
- ``auto_wrap_kwargs`` with the arguments in ``auto_wrap_kwargs`` and
- ``fsdp_kwargs``.
- Precondition: ``auto_wrap_policy`` contains the arguments expected by
- ``_recursive_wrap()``, where ``auto_wrap_policy`` is not ``None``.
- ``fsdp_kwargs`` contains all FSDP arguments except ``module``.
- """
- from torch.distributed.fsdp._utils import _contains_batchnorm, _override_batchnorm_mixed_precision
- from torch.distributed.fsdp.wrap import _or_policy, _wrap_batchnorm_individually
-
- auto_wrap_policy = auto_wrap_kwargs['auto_wrap_policy']
- root_module = auto_wrap_kwargs['module']
- assert auto_wrap_policy is not None
- # For auto wrapping, submodules should not already be wrapped with FSDP
- # since double wrapping is not supported
- for module_name, module in root_module.named_modules():
- if isinstance(module, FullyShardedDataParallel):
- raise ValueError(f'Expected {module_name} to NOT be FullyShardedDataParallel '
- 'if using an `auto_wrap_policy`')
- mixed_precision = fsdp_kwargs['mixed_precision']
- if mixed_precision is not None and _contains_batchnorm(root_module):
- _override_batchnorm_mixed_precision(root_module)
- auto_wrap_policy = functools.partial(_or_policy, policies=[_wrap_batchnorm_individually, auto_wrap_policy])
- warnings.warn('Both mixed precision and an `auto_wrap_policy` were specified '
- 'for FSDP, where the wrapped module has batch norm submodules. '
- 'The batch norm submodules will be wrapped as separate FSDP '
- 'instances with mixed precision disabled since some batch norm '
- 'kernels do not support low precision.')
- auto_wrap_kwargs['auto_wrap_policy'] = auto_wrap_policy
- # CHANGE: Add process group cache and call our custom _recursive_wrap
- auto_wrap_kwargs['process_group_cache'] = {}
- _custom_recursive_wrap_t1p13p1(**auto_wrap_kwargs, **fsdp_kwargs)
-
-
def _custom_recursive_wrap_t2p0p1(
module: nn.Module,
auto_wrap_policy: Callable,
@@ -370,7 +229,7 @@ def _custom_recursive_wrap_t2p0p1(
modified version of
https://github.com/pytorch/pytorch/blob/96ca226a7332be0d8f3d6159d0c797e032ab0721/torch/distributed/fsdp/wrap.py#L320
which recursively wraps modules as FSDP modules for parameter sharding.
- This modification enables the user to pass custom FSDP arguements for every wrapped module.
+ This modification enables the user to pass custom FSDP arguments for every wrapped module.
The added process_group_cache enables different FSDP modules to, when appropriate, use the
same process group instead of instantiating a new process group.
@@ -468,7 +327,7 @@ def _custom_auto_wrap_t2p0p1(
modified version of
https://github.com/pytorch/pytorch/blob/96ca226a7332be0d8f3d6159d0c797e032ab0721/torch/distributed/fsdp/_wrap_utils.py#L31
FSDP's _auto_wrap recursively wraps modules as FSDP modules for parameter sharding.
- This modification enables the user to pass custom FSDP arguements for every wrapped module.
+ This modification enables the user to pass custom FSDP arguments for every wrapped module.
The added process_group_cache enables different FSDP modules to, when appropriate, use the
same process group instead of instantiating a new process group.
@@ -756,361 +615,421 @@ def _sharded_pre_load_state_dict_hook(
_enter_unshard_params_ctx(module, fsdp_state, writeback=True)
-def fsdp_state_has_default_pg(state: '_FSDPState') -> bool:
- """Indicates whether FlatParamHandle has the default process group.
-
- Args:
- handle (_FSDPState): FSDP State object
-
- Returns:
- bool: True if the ProcessGroup of the _FSDPState object is the default process group. False
- otherwise.
- """
- if state.process_group is None:
- # If no process group is attached to the _FSDPState, assume it uses default process group.
- return True
- return len(get_process_group_ranks(state.process_group)) == dist.get_world_size()
-
-
-def fsdp_state_pg_ranks(state: '_FSDPState') -> Tuple[int, ...]:
- """Gets the ranks included in the ProcessGroup of an _FSDPState.
-
- Args:
- state (_FSDPState): FSDP State object
-
- Returns:
- Tuple[int]: Ranks for the FSDP State's process group.
- """
- if state.process_group is None:
- # If no process group is attached to the _FSDPState, assume it uses default process group.
- return tuple(range(dist.get_world_size()))
- else:
- return tuple(get_process_group_ranks(state.process_group))
-
-
-def _wait_for_computation_stream(
- computation_stream: torch.Stream,
- root_state: '_FSDPState',
- pre_unshard_stream: torch.Stream,
-):
- """Unshard and pre-unshard streams wait for computation stream.
+if version.parse(torch.__version__) > version.parse('2.2.9') and version.parse(
+ torch.__version__) < version.parse('2.3.1'):
+ import copy
+
+ from torch.distributed._tensor import DeviceMesh, DTensor, Replicate
+ from torch.distributed._tensor import Shard as DShard
+ from torch.distributed.algorithms._comm_hooks import default_hooks
+ from torch.distributed.device_mesh import _mesh_resources
+ from torch.distributed.distributed_c10d import _get_default_group
+ from torch.distributed.fsdp._common_utils import _FSDPState
+ from torch.distributed.fsdp._init_utils import (HYBRID_SHARDING_STRATEGIES, ProcessGroupType,
+ _get_default_comm_hook_state, _init_intra_and_inter_node_groups,
+ _is_valid_hybrid_shard_pg_type, _init_extension)
+ from torch.distributed.fsdp.fully_sharded_data_parallel import (_annotate_modules_for_dynamo, _auto_wrap,
+ _check_orig_params_flattened, _init_buffer_state,
+ _init_core_state, _init_device_handle,
+ _init_ignored_module_states,
+ _init_param_handle_from_module,
+ _init_prefetching_state, _init_runtime_state,
+ _init_state_dict_state,
+ _register_all_state_dict_hooks,
+ _register_flat_param)
+ from torch.distributed.fsdp.wrap import CustomPolicy, ModuleWrapPolicy, _Policy
+ from torch.distributed.tensor.parallel.fsdp import DTensorExtensions
+
+ def all_gather_dtensor_t2p3p0(
+ self,
+ tensor: DTensor,
+ parent_mesh: Optional[DeviceMesh],
+ ) -> torch.Tensor:
+ """All gather a DTensor in its FSDP dimension and return the local tensor."""
+ assert parent_mesh == tensor.device_mesh
+
+ placements = list(copy.deepcopy(tensor.placements))
+ # FSDP + TP: [Shard(0), tp_placement] -> [Replicate(), tp_placement]
+ # HSDP + TP: [Replicate(), Shard(0), tp_placement] -> [Replicate(), Replicate(), tp_placement]
+ for i in range(0, len(placements) - 1):
+ placements[i] = Replicate()
+ tensor = tensor.redistribute(
+ device_mesh=tensor.device_mesh,
+ placements=placements,
+ )
+ return tensor.to_local()
- Has the unshard and pre-unshard streams wait for the computation stream.
- For example, this should be called in the FSDP root's pre-forward to
- respect optimizer step computation.
- """
- # Tracing does not need to wait
- if torch.distributed._functional_collectives.is_torchdynamo_compiling():
- return
- # Ensure all unshard streams wait for the computation stream.
- unshard_streams = set()
- for fsdp_state in root_state._all_fsdp_states:
- unshard_streams.add(fsdp_state._unshard_stream)
- for unshard_stream in unshard_streams:
- unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined]
- # Having the pre-all-gather stream wait for the current stream even if we
- # do not leverage the pre-all-gather stream is tolerable since this only
- # runs once per iteration
- pre_unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined]
+ def chunk_dtensor_t2p3p0(
+ self,
+ tensor: torch.Tensor,
+ rank: int,
+ device_mesh: DeviceMesh,
+ ) -> DTensor:
+ """Shard a tensor to chunks along the first dimension.
+ The local rank will gets its corresponding chunk as the local tensor to create a DTensor.
+ """
+ parent_mesh = _mesh_resources.get_parent_mesh(device_mesh)
+ if parent_mesh is None:
+ raise RuntimeError('No parent device_mesh is found for FSDP device_mesh.')
+ # if parent_mesh.ndim != 2:
+ # raise RuntimeError(
+ # f"Found parent device_mesh of ndim={parent_mesh.ndim},",
+ # "but only 2D meshes are currently supported.",
+ # )
+
+ # We need to explicitly call .detach() to return a new tensor detached from the current graph.
+ tensor = tensor.clone().detach()
+
+ # When a layer is not involved in TP, then the tensor will not be a DTensor.
+ # e.g. When a layer is not specified in the parallelize_plan, TP will have no effect on the layer.
+ # e.g. When you do PairwiseParallel on a 3 layer model, TP will have no effect on the third layer.
+ if isinstance(tensor, torch.Tensor) and not isinstance(tensor, DTensor):
+
+ # For tensors, it is replicated across tp dimension and sharded across FSDP dimension.
+ # TP is the inner dimension and FSDP is the outer dimension.
+ # Therefore, shard placements for tensor is (Shard(0), Replicate()).
+ replicate_placements = [Replicate() for _ in range(parent_mesh.ndim)]
+ shard_placements = [Replicate() for _ in range(parent_mesh.ndim)]
+ shard_placements[0] = DShard(0) # type: ignore[call-overload]
+
+ return DTensor.from_local(tensor, parent_mesh, replicate_placements).redistribute(
+ device_mesh=parent_mesh,
+ placements=shard_placements,
+ )
-@no_type_check
-def _root_pre_forward(
- state: '_FSDPState',
- module: nn.Module,
- args,
- kwargs,
-) -> None:
- """Runs pre-forward logic specific to the root FSDP instance.
+ else:
+ tp_placements = tensor.placements
+ tp_placement = tp_placements[0]
+
+ tensor = tensor.to_local()
+
+ if parent_mesh.ndim <= 2:
+ # For DTensors, it is sharded across tp dimension first and then sharded across FSDP dimension.
+ # TP is the inner dimension and FSDP is the outer dimension.
+ # Therefore, shard placements for tensor is (Shard(0), tp_placement).
+ replicate_placements = [Replicate() for _ in range(parent_mesh.ndim)]
+ replicate_placements[-1] = tp_placement # type: ignore[call-overload]
+ shard_placements = [DShard(0) for _ in range(parent_mesh.ndim)] # type: ignore[misc]
+ shard_placements[-1] = tp_placement # type: ignore[call-overload]
+
+ elif parent_mesh.ndim == 3:
+ replicate_placements = [Replicate(), Replicate(), tp_placement]
+ shard_placements = [Replicate(), DShard(0), tp_placement] # type: ignore[misc]
+
+ return DTensor.from_local(tensor, parent_mesh, replicate_placements).redistribute(
+ device_mesh=parent_mesh,
+ placements=shard_placements,
+ )
- This should run before any individual module's pre-forward. This starts
- with an attempt at lazy initialization (which only runs non-vacuously once).
- Otherwise, if this is called on a non-root FSDP instance, then it returns
- directly.
- """
- from torch.distributed.fsdp._common_utils import _is_composable
- from torch.distributed.fsdp._runtime_utils import (_cast_buffers_to_dtype_and_device,
- _get_buffers_and_dtypes_for_computation, _lazy_init,
- _reset_flat_param_grad_info_if_needed, _root_cast_forward_input)
- from torch.distributed.utils import _p_assert, _to_kwargs
- with torch.profiler.record_function('FullyShardedDataParallel._root_pre_forward'):
- _lazy_init(state, module)
- _p_assert(state._is_root is not None, 'Expects a root FSDP to have been set')
- if not state._is_root:
- # Always cast forward inputs in the root of this local FSDP unit for mixed
- # precision, as this is where mixed precision could be configed.
- # This is more useful for auto wrapping that is recommended in composable path.
- # For manual wrapping, cast forward inputs on each local FSDP unit root will
- # increase some overhead, so not turned on for model wrapper path right now where
- # manual wrapping is more broadly used.
- if _is_composable(state):
- return _root_cast_forward_input(state, module, args, kwargs)
- return args, kwargs
-
- # We cast buffers back to full precision if we're forcing full precision. Disjointly, we check if buffers
- # are in full precision and if we should cast them back to lower precision, which happens when
- # exiting eval() mode.
- handle = state._handle
- if handle:
- should_cast_buffers_to_full_prec = handle._force_full_precision
+ DTensorExtensions.all_gather_dtensor = all_gather_dtensor_t2p3p0
+ DTensorExtensions.chunk_dtensor = chunk_dtensor_t2p3p0
+
+ def _is_valid_hybrid_shard_device_mesh_t2p3p0(device_mesh: DeviceMesh) -> bool:
+ #parent_mesh = _mesh_resources.get_parent_mesh(device_mesh)
+ #if parent_mesh is not None:
+ # raise RuntimeError(
+ # f"Found device_mesh {device_mesh} passed in has a parent device_mesh {parent_mesh}.",
+ # "Hybrid sharding + TP is not supported yet.",
+ # )
+ return isinstance(device_mesh, DeviceMesh) and device_mesh.ndim == 2
+
+ def _init_process_group_state_for_hybrid_shard_t2p3p0(
+ state: _FSDPState,
+ process_group: ProcessGroupType,
+ device_mesh: DeviceMesh,
+ ) -> _FSDPState:
+ if device_mesh:
+ if _is_valid_hybrid_shard_device_mesh_t2p3p0(device_mesh):
+ state._device_mesh = device_mesh
+ # We currently only allow _inter_node_pg to be the outermost dimension, and the
+ # process_group(intra_node) to be the innermost dimension.
+ state._inter_node_pg = device_mesh.get_group(mesh_dim=0)
+ state.process_group = device_mesh.get_group(mesh_dim=1)
+ else:
+ raise ValueError('Expected device_mesh to have ndim=2 '
+ f'but got {len(device_mesh.get_group())}')
+ elif process_group is None:
+ default_group = _get_default_group()
+ intra_node_group, inter_node_group = _init_intra_and_inter_node_groups(default_group,
+ state._device_handle.device_count())
+ # we shard across intra-node
+ state.process_group = intra_node_group
+ # save _inter_node_pg to allreduce across.
+ state._inter_node_pg = inter_node_group
else:
- should_cast_buffers_to_full_prec = True
+ # Check type and assign state.process_group and state._inter_node_pg.
+ if _is_valid_hybrid_shard_pg_type(process_group):
+ # Assuming that user passed in as intra node group and inter node group
+ # as documented.
+ state.process_group, state._inter_node_pg = process_group
+ else:
+ raise ValueError('Expected process_group to be passed in as either None or '
+ f'Tuple[dist.ProcessGroup, dist.ProcessGroup] but got {type(process_group)}')
+ # Create state for allreduce
+ state._inter_node_state = _get_default_comm_hook_state(process_group=state._inter_node_pg,)
+ return state
+
+ def _init_process_group_state_t2p3p0(
+ state: _FSDPState,
+ process_group: ProcessGroupType,
+ sharding_strategy: ShardingStrategy,
+ policy: Optional[_Policy],
+ device_mesh: Optional[DeviceMesh] = None,
+ ) -> _FSDPState:
+ if process_group is not None and device_mesh is not None:
+ raise ValueError('Cannot pass both process_group and device_mesh at the '
+ 'same time. Please just pass only one of them.')
+ is_hybrid_strategy = sharding_strategy in HYBRID_SHARDING_STRATEGIES
+ if is_hybrid_strategy:
+ if process_group is None and policy is None and device_mesh is None:
+ # Raise an error here, since this is manual wrapping with no process group
+ # passed in, there is no way to ensure all wrapped FSDP instances use the same
+ # process groups.
+ raise ValueError(
+ f'Manual wrapping with {sharding_strategy}',
+ 'requires explicit specification of process group or device_mesh.',
+ )
+ else:
+ state = _init_process_group_state_for_hybrid_shard_t2p3p0(state, process_group, device_mesh)
+ else:
+ if device_mesh:
+ state._device_mesh = device_mesh
+ state.process_group = device_mesh.get_group(mesh_dim=0)
+ else:
+ state.process_group = (process_group if process_group is not None else _get_default_group())
+
+ state.rank = state.process_group.rank()
+ state.world_size = state.process_group.size()
+ data_parallel_world_size = state.world_size
+ if is_hybrid_strategy:
+ data_parallel_world_size *= state._inter_node_pg.size()
+ state._gradient_predivide_factor = (
+ default_hooks.DefaultState._get_gradient_predivide_factor(data_parallel_world_size))
+ state._gradient_postdivide_factor = (data_parallel_world_size / state._gradient_predivide_factor)
+ return state
+
+ def init_fn_t2p3p0(
+ self,
+ module: nn.Module,
+ process_group: ProcessGroupType = None,
+ sharding_strategy: Optional[ShardingStrategy] = None,
+ cpu_offload: Optional[CPUOffload] = None,
+ auto_wrap_policy: Optional[Union[Callable, ModuleWrapPolicy, CustomPolicy]] = None,
+ backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
+ mixed_precision: Optional[MixedPrecision] = None,
+ ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
+ param_init_fn: Optional[Callable[[nn.Module], None]] = None,
+ device_id: Optional[Union[int, torch.device]] = None,
+ sync_module_states: bool = False,
+ forward_prefetch: bool = False,
+ limit_all_gathers: bool = True,
+ use_orig_params: bool = False,
+ ignored_states: Union[Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]] = None,
+ device_mesh: Optional[DeviceMesh] = None,
+ ):
+ """Docstring for lint."""
+ torch._C._log_api_usage_once('torch.distributed.fsdp')
+ super(FullyShardedDataParallel, self).__init__()
+ _init_ignored_module_states(self, module, ignored_modules, ignored_states)
+ _init_device_handle(self, module, self._ignored_params, device_id)
- if should_cast_buffers_to_full_prec:
- _cast_buffers_to_dtype_and_device(
- buffers=dict(module.named_buffers()).values(),
- buffer_dtypes=list(state._buffer_name_to_orig_dtype.values()),
- device=state.compute_device,
- )
- # This flag is only set when we cast buffers to full precision, to avoid the
- # CPU overhead that can stem from retrieving all buffers and their types in the
- # following else branch.
- state._needs_buffer_dtype_restore_check = True
- elif getattr(state, '_needs_buffer_dtype_restore_check', False):
- # Check if buffers are in full precision and we need to cast them
- # back down.
- (
- buffers,
- buffer_dtypes_for_computation,
- ) = _get_buffers_and_dtypes_for_computation(state, module)
- if len(buffers) > 0 and len(buffer_dtypes_for_computation) > 0:
- if any(buffer.dtype != buffer_dtype_for_computation
- for buffer, buffer_dtype_for_computation in zip(buffers, buffer_dtypes_for_computation)):
- # Assume we have to cast everything if there is one mismatch
- _cast_buffers_to_dtype_and_device(buffers, buffer_dtypes_for_computation, state.compute_device)
- # We don't have to check this again until we cast buffers to full precision again.
- state._needs_buffer_dtype_restore_check = False
-
- if state.forward_prefetch:
- handles = []
- for fsdp_state in state._all_fsdp_states:
- if fsdp_state._handle:
- handles.append(fsdp_state._handle)
- for handle in handles:
- handle._needs_pre_forward_unshard = True
- handle._prefetched = False
-
- _wait_for_computation_stream(
- state._device_handle.current_stream(),
- state,
- state._pre_unshard_stream,
- )
- _reset_flat_param_grad_info_if_needed(state._all_handles)
-
- # Prepares the forward inputs by moving them to ``compute_device``
- # TODO: Do not use the side stream for tensor copies for now; investigate
- # the perf with/without it.
- with torch.profiler.record_function('FullyShardedDataParallel._to_kwargs'):
- args_tuple, kwargs_tuple = _to_kwargs(args, kwargs, state.compute_device, False)
- args = args_tuple[0]
- kwargs = kwargs_tuple[0]
-
- return _root_cast_forward_input(state, module, args, kwargs)
-
-
-def forward(self, *args: Any, **kwargs: Any) -> Any:
- """Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic."""
- from torch.distributed.fsdp._runtime_utils import (_post_forward, _post_forward_reshard, _pre_forward,
- _pre_forward_unshard)
- from torch.distributed.utils import _p_assert
- handle = self._handle
- with torch.autograd.profiler.record_function('FullyShardedDataParallel.forward'):
- args, kwargs = _root_pre_forward(self, self, args, kwargs)
- unused = None
- args, kwargs = _pre_forward(
+ # Add module annotations for Dynamo support (see function for details)
+ _annotate_modules_for_dynamo(module, self._ignored_modules, use_orig_params)
+
+ # Initializes self.process_group, along with rank and world size. This will
+ # also set another attribute, _inter_node_pg, to control the process group
+ # over which sharding occurs, if sharding_strategy is {HYBRID_SHARD, _HYBRID_SHARD_ZERO2}.
+ # Note that this is done before auto_wrapping, so that child FSDP modules simply pick up
+ # the same process group state as the root FSDP module.
+ self._device_mesh = device_mesh
+ _init_process_group_state_t2p3p0(
self,
- handle,
- _pre_forward_unshard,
- self._fsdp_wrapped_module,
- args,
- kwargs,
+ process_group,
+ sharding_strategy,
+ auto_wrap_policy,
+ device_mesh,
)
- if handle:
- _p_assert(
- handle.flat_param.device == self.compute_device,
- 'Expected `FlatParameter` to be on the compute device '
- f'{self.compute_device} but got {handle.flat_param.device}',
+ if auto_wrap_policy is not None:
+ root_kwargs = {
+ 'process_group': process_group,
+ 'sharding_strategy': sharding_strategy,
+ 'cpu_offload': cpu_offload,
+ 'backward_prefetch': backward_prefetch,
+ 'mixed_precision': mixed_precision,
+ 'param_init_fn': param_init_fn,
+ 'device_id': device_id,
+ 'sync_module_states': sync_module_states,
+ 'forward_prefetch': forward_prefetch,
+ 'limit_all_gathers': limit_all_gathers,
+ 'use_orig_params': use_orig_params,
+ 'ignored_states': self._ignored_params,
+ 'device_mesh': device_mesh,
+ }
+ if sharding_strategy in HYBRID_SHARDING_STRATEGIES and device_mesh is None:
+ # Share root process groups with children to maintain
+ # the invariant that all FSDP modules will have the same
+ # process groups.
+ root_kwargs['process_group'] = (self.process_group, self._inter_node_pg)
+
+ _auto_wrap(
+ module,
+ auto_wrap_policy,
+ self._ignored_modules,
+ self._ignored_params,
+ root_kwargs,
+ FullyShardedDataParallel,
)
- output = self._fsdp_wrapped_module(*args, **kwargs)
- return _post_forward(self, handle, _post_forward_reshard, self, unused, output)
-
-
-@no_type_check
-def _share_state_and_init_handle_attrs_t2p1(
- root_state: '_FSDPState',
- root_module: nn.Module,
-) -> None:
- """Shares state from ``root_state`` to other FSDP states.
- Shares data structure state from the ``root_state`` to all FSDP states in
- ``root_module`` 's module tree, and initializes handle attributes. These are
- done together to require a single loop over the states. This function has
- been modified to assign a different unshard stream to each process group.
- """
- from torch.distributed.fsdp._runtime_utils import (HOMOGENEOUS_ATTR_NAMES, _init_device_mesh,
- _validate_and_get_hybrid_shard_state)
- from torch.distributed.utils import _p_assert
-
- handle = root_state._handle
- if handle:
- handle.init_flat_param_attributes()
- _validate_and_get_hybrid_shard_state(root_module)
- attr_name_to_values: Dict[str, Set[Any]] = {}
- for attr_name in HOMOGENEOUS_ATTR_NAMES:
- attr_name_to_values[attr_name] = set()
- root_state._all_handles = root_state._exec_order_data.all_handles # share reference
- root_state._device_mesh = _init_device_mesh(root_state)
- # Update _has_optim_in_backward for each handle.
- for handle in root_state._all_handles:
- flat_param = handle.flat_param
- if hasattr(flat_param, '_in_backward_optimizers'):
- raise RuntimeError('FSDP optimizer in backward only supported with use_orig_params=True!')
- handle._has_optim_in_backward = flat_param._params is not None and any(
- hasattr(param, '_in_backward_optimizers') for param in flat_param._params)
-
- # Patching so that _FSDPStates with different process groups have separate unshard streams.
- # Keep track of any new unshard streams we may have to add for specific process groups.
- fsdp_pg_unshard_streams = {}
- try:
- unshard_priority = root_state._unshard_stream.priority
- except AttributeError:
- # Use the default priority of 0 if the stream has no assigned priority.
- unshard_priority = 0
- for fsdp_state in root_state._all_fsdp_states:
- for attr_name in HOMOGENEOUS_ATTR_NAMES:
- _p_assert(
- hasattr(fsdp_state, attr_name),
- f'FSDP state missing attribute {attr_name}',
- )
- attr_name_to_values[attr_name].add(getattr(fsdp_state, attr_name))
- if fsdp_state is root_state:
- continue
- # Relax the assert for non-root FSDP instances in case the nested
- # initialized module is wrapped again in FSDP later (e.g. after
- # training to run inference)
- _p_assert(
- fsdp_state._is_root is None or not fsdp_state._is_root,
- "Non-root FSDP instance's `_is_root` should not have been "
- 'set yet or should have been set to `False`',
+ backward_prefetch_limit = 1
+ forward_prefetch_limit = 1
+ _init_core_state(
+ self,
+ sharding_strategy,
+ mixed_precision,
+ cpu_offload,
+ limit_all_gathers,
+ use_orig_params,
+ backward_prefetch_limit,
+ forward_prefetch_limit,
)
- fsdp_state._is_root = False
-
- # Take care of any new unshard streams we have to create for non-default process groups.
- if fsdp_state_has_default_pg(fsdp_state):
- # If using default process group, unshard stream is the same as root fsdp instance.
- fsdp_state._unshard_stream = root_state._unshard_stream
- else:
- # Otherwise, unshard stream is separate.
- state_pg_ranks = fsdp_state_pg_ranks(fsdp_state)
- if state_pg_ranks in fsdp_pg_unshard_streams:
- # We have created the unshard stream for this process group already. Use it.
- fsdp_state._unshard_stream = fsdp_pg_unshard_streams[state_pg_ranks]
- else:
- # We don't have an unshard stream for this process group yet. Make it.
- fsdp_state._unshard_stream = fsdp_state._device_handle.Stream(priority=unshard_priority)
- fsdp_pg_unshard_streams[state_pg_ranks] = fsdp_state._unshard_stream
-
- # All other stream assignments stay common across all of FSDP.
- fsdp_state._post_backward_stream = root_state._post_backward_stream
- fsdp_state._pre_unshard_stream = root_state._pre_unshard_stream
- fsdp_state._all_reduce_stream = root_state._all_reduce_stream
- fsdp_state._default_stream = root_state._default_stream
- fsdp_state._exec_order_data = root_state._exec_order_data
- fsdp_state._free_event_queue = root_state._free_event_queue
- fsdp_state._device_mesh = root_state._device_mesh
- handle = fsdp_state._handle
- if handle:
- handle.init_flat_param_attributes()
- for attr_name, attr_values in attr_name_to_values.items():
- if len(attr_values) != 1:
- raise ValueError(f'Expects one homogeneous value for {attr_name} but got {attr_values}')
-
+ _init_runtime_state(self)
+ _init_prefetching_state(self, backward_prefetch, forward_prefetch)
+ _init_buffer_state(self, module)
+ # extension needs to be set before `_init_param_handle_from_module()`
+ _init_extension(self, device_mesh)
+ _init_param_handle_from_module(
+ self,
+ module,
+ device_id,
+ param_init_fn,
+ sync_module_states,
+ )
+ self._fsdp_wrapped_module = module
+ if not use_orig_params:
+ _check_orig_params_flattened(self, self._ignored_params)
+ _register_flat_param(self, self)
-@no_type_check
-def _share_state_and_init_handle_attrs_t2p2(
- root_state: '_FSDPState',
- root_module: nn.Module,
-) -> None:
- """Shares state from ``root_state`` to other FSDP states.
+ # `_state_dict_type` controls the `state_dict()` behavior, which is
+ # implemented using post-save and pre-load hooks
+ _init_state_dict_state(self)
+ _register_all_state_dict_hooks(self)
- Shares data structure state from the ``root_state`` to all FSDP states in
- ``root_module`` 's module tree, and initializes handle attributes. These are
- done together to require a single loop over the states. This function has
- been modified to assign a different unshard stream to each process group.
- """
- from torch.distributed.fsdp._runtime_utils import HOMOGENEOUS_ATTR_NAMES, _validate_and_get_hybrid_shard_state
- from torch.distributed.utils import _p_assert
-
- handle = root_state._handle
- if handle:
- handle.init_flat_param_attributes()
- _validate_and_get_hybrid_shard_state(root_module)
- attr_name_to_values: Dict[str, Set[Any]] = {}
- for attr_name in HOMOGENEOUS_ATTR_NAMES:
- attr_name_to_values[attr_name] = set()
- root_state._all_handles = root_state._exec_order_data.all_handles # share reference
- # Update _has_optim_in_backward for each handle.
- for handle in root_state._all_handles:
- flat_param = handle.flat_param
- if hasattr(flat_param, '_in_backward_optimizers'):
- raise RuntimeError('FSDP optimizer in backward only supported with use_orig_params=True!')
- handle._has_optim_in_backward = flat_param._params is not None and any(
- hasattr(param, '_in_backward_optimizers') for param in flat_param._params)
- if handle._has_optim_in_backward:
- torch._C._log_api_usage_once('fsdp.optimizer_in_backward')
-
- # Patching so that _FSDPStates with different process groups have separate unshard streams.
- # Keep track of any new unshard streams we may have to add for specific process groups.
- fsdp_pg_unshard_streams = {}
- try:
- unshard_priority = root_state._unshard_stream.priority
- except AttributeError:
- # Use the default priority of 0 if the stream has no assigned priority.
- unshard_priority = 0
- for fsdp_state in root_state._all_fsdp_states:
- for attr_name in HOMOGENEOUS_ATTR_NAMES:
- _p_assert(
- hasattr(fsdp_state, attr_name),
- f'FSDP state missing attribute {attr_name}',
+ from torch.distributed.checkpoint.state_dict import StateDictOptions, _StateDictInfo
+
+ def _verify_options_t2p3p0(
+ model: nn.Module,
+ optims: Tuple[torch.optim.Optimizer, ...],
+ optim_only: bool,
+ *,
+ submodules: Optional[Set[nn.Module]] = None,
+ options: Optional[StateDictOptions] = None,
+ ) -> _StateDictInfo:
+ """Verify the model and options passed by the user and generates _StateDictInfo."""
+ from torch.distributed.checkpoint.state_dict import StateDictOptions, _get_fqns, _StateDictInfo
+ from torch.distributed.fsdp import FullOptimStateDictConfig, FullStateDictConfig
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+ from torch.distributed.fsdp import (OptimStateDictConfig, ShardedOptimStateDictConfig, ShardedStateDictConfig,
+ StateDictConfig, StateDictType)
+
+ if optim_only and not optims:
+ raise RuntimeError('Optimizers are not passed in but optim_only is set to True.')
+
+ options = options or StateDictOptions()
+ assert options is not None # pyright
+
+ fqn_param_mapping: Dict[Union[str, torch.Tensor], Union[Set[str], torch.Tensor]] = {}
+ all_fqns = set()
+ for name, param in model.named_parameters():
+ fqns = _get_fqns(model, name)
+ fqns = {fqn.replace('_checkpoint_wrapped_module.', '') for fqn in fqns}
+ fqn_param_mapping[param] = fqns
+ for fqn in fqns:
+ fqn_param_mapping[fqn] = param
+ all_fqns.add(fqn)
+
+ submodule_prefixes = set()
+ if submodules:
+ submodules = set(submodules)
+ for name, module in model.named_modules():
+ if module not in submodules:
+ continue
+ fqns = _get_fqns(model, name)
+ assert len(fqns) == 1, 'Submodule FQN should only have 1 instance'
+ for fqn in fqns:
+ submodule_prefixes.add(f'{fqn}.')
+ fsdp_modules = FSDP.fsdp_modules(model)
+ state_dict_config: StateDictConfig
+ optim_state_dict_config: OptimStateDictConfig
+ fsdp_context: Callable
+ if fsdp_modules:
+ # FSDP API only work if at least one FSDP instance exists.
+ if options.full_state_dict:
+ state_dict_config = FullStateDictConfig(offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload)
+ optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=options.cpu_offload,
+ rank0_only=options.cpu_offload)
+ state_dict_type = StateDictType.FULL_STATE_DICT
+ else:
+ state_dict_config = ShardedStateDictConfig(offload_to_cpu=options.cpu_offload,)
+ optim_state_dict_config = ShardedOptimStateDictConfig(offload_to_cpu=options.cpu_offload,)
+ state_dict_type = StateDictType.SHARDED_STATE_DICT
+
+ fsdp_context = functools.partial(
+ FSDP.state_dict_type,
+ module=model,
+ state_dict_type=state_dict_type,
+ state_dict_config=state_dict_config,
+ optim_state_dict_config=optim_state_dict_config,
)
- attr_name_to_values[attr_name].add(getattr(fsdp_state, attr_name))
- if fsdp_state is root_state:
- continue
- # Relax the assert for non-root FSDP instances in case the nested
- # initialized module is wrapped again in FSDP later (e.g. after
- # training to run inference)
- _p_assert(
- fsdp_state._is_root is None or not fsdp_state._is_root,
- "Non-root FSDP instance's `_is_root` should not have been "
- 'set yet or should have been set to `False`',
+ else:
+ fsdp_context = contextlib.nullcontext
+ return _StateDictInfo(
+ **asdict(options),
+ fqn_param_mapping=fqn_param_mapping,
+ all_fqns=all_fqns,
+ submodule_prefixes=submodule_prefixes,
+ fsdp_context=fsdp_context,
+ fsdp_modules=cast(List[nn.Module], fsdp_modules),
+ handle_model=not optim_only,
+ handle_optim=(len(optims) > 0),
)
- fsdp_state._is_root = False
- # Take care of any new unshard streams we have to create for non-default process groups.
- if fsdp_state_has_default_pg(fsdp_state):
- # If using default process group, unshard stream is the same as root fsdp instance.
- fsdp_state._unshard_stream = root_state._unshard_stream
- else:
- # Otherwise, unshard stream is separate.
- state_pg_ranks = fsdp_state_pg_ranks(fsdp_state)
- if state_pg_ranks in fsdp_pg_unshard_streams:
- # We have created the unshard stream for this process group already. Use it.
- fsdp_state._unshard_stream = fsdp_pg_unshard_streams[state_pg_ranks]
- else:
- # We don't have an unshard stream for this process group yet. Make it.
- fsdp_state._unshard_stream = fsdp_state._device_handle.Stream(priority=unshard_priority)
- fsdp_pg_unshard_streams[state_pg_ranks] = fsdp_state._unshard_stream
-
- # All other stream assignments stay common across all of FSDP.
- fsdp_state._post_backward_stream = root_state._post_backward_stream
- fsdp_state._pre_unshard_stream = root_state._pre_unshard_stream
- fsdp_state._all_reduce_stream = root_state._all_reduce_stream
- fsdp_state._default_stream = root_state._default_stream
- fsdp_state._exec_order_data = root_state._exec_order_data
- fsdp_state._free_event_queue = root_state._free_event_queue
- handle = fsdp_state._handle
- if handle:
- handle.init_flat_param_attributes()
- for attr_name, attr_values in attr_name_to_values.items():
- if len(attr_values) != 1:
- raise ValueError(f'Expects one homogeneous value for {attr_name} but got {attr_values}')
+ from torch.distributed.fsdp._optim_utils import FSDPParamInfo
+ from torch.distributed._state_dict_utils import _gather_state_dict
+ def _shard_orig_param_state(
+ fsdp_param_info: FSDPParamInfo,
+ fqn: str,
+ optim_state: Dict[str, Any],
+ ) -> Dict[str, Any]:
+ """Shard function monkeypatch.
+
+ Shard the optimizer state for the original parameter with the name ``fqn``.
+ This API should only be used when ``use_orig_params`` is True.
+ """
+ if not optim_state:
+ return {}
+ fsdp_state = fsdp_param_info.state
+ flat_param = fsdp_param_info.handle.flat_param
+ param_idx = fsdp_param_info.param_indices[fqn]
+ shard_param_info = flat_param._shard_param_infos[param_idx] # type: ignore[attr-defined]
+ optim_state = _gather_state_dict(
+ optim_state,
+ pg=fsdp_state.process_group,
+ device=fsdp_state.compute_device,
+ )
+ if not shard_param_info.in_shard:
+ return {}
+ # Flatten and shard the state.
+ new_optim_state: Dict[str, Any] = {}
+ intra_param_start_idx = shard_param_info.intra_param_start_idx
+ intra_param_end_idx = shard_param_info.intra_param_end_idx
+ for state_name, value in optim_state.items():
+ if (
+ torch.is_tensor(value)
+ and value.dim() > 0
+ and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD
+ ):
+ value = value.flatten()[intra_param_start_idx : intra_param_end_idx + 1].clone() # type: ignore[operator]
+ new_optim_state[state_name] = value
+ torch.cuda.synchronize()
+ return new_optim_state
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index c8c6d325e0..b7c9bd4d4a 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -28,18 +28,20 @@
import torch.distributed
import torch.nn as nn
import torch.utils.data
-from packaging import version
+from torch._dynamo import OptimizedModule
from torch.cuda.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.nn.parallel import DistributedDataParallel
+from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader, DistributedSampler
from torchmetrics import Metric
-from composer.callbacks import CheckpointSaver, OptimizerMonitor
+from composer.callbacks import CheckpointSaver, MemorySnapshot, OOMObserver, OptimizerMonitor
from composer.core import (Algorithm, AlgorithmPass, Batch, Callback, DataSpec, Engine, Evaluator, Event, Precision,
- PyTorchScheduler, State, Time, Timestamp, TimeUnit, TrainerMode, ensure_data_spec,
- ensure_evaluator, ensure_time, get_precision_context, validate_eval_automicrobatching)
+ State, Time, Timestamp, TimeUnit, TrainerMode, ensure_data_spec, ensure_evaluator,
+ ensure_time, get_precision_context, validate_eval_automicrobatching)
from composer.devices import Device, DeviceCPU, DeviceGPU, DeviceMPS, DeviceTPU
-from composer.loggers import (ConsoleLogger, Logger, LoggerDestination, MosaicMLLogger, ProgressBarLogger,
+from composer.loggers import (ConsoleLogger, Logger, LoggerDestination, MLFlowLogger, MosaicMLLogger, ProgressBarLogger,
RemoteUploaderDownloader, WandBLogger)
from composer.loggers.mosaicml_logger import MOSAICML_ACCESS_TOKEN_ENV_VAR, MOSAICML_PLATFORM_ENV_VAR
from composer.models import ComposerModel
@@ -54,8 +56,9 @@
ensure_tuple, export_with_logger, extract_hparams, format_name_with_dist,
get_composer_env_dict, get_device, get_file, is_tpu_installed, map_collection,
maybe_create_object_store_from_uri, maybe_create_remote_uploader_downloader_from_uri,
- model_eval_mode, parse_uri, reproducibility, using_torch_2)
+ model_eval_mode, parse_uri, partial_format, reproducibility)
from composer.utils.misc import is_model_deepspeed
+from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_FORMAT_KEY, MLFLOW_RUN_ID_FORMAT_KEY
if is_tpu_installed():
import torch_xla.core.xla_model as xm
@@ -66,7 +69,7 @@
__all__ = ['Trainer']
# syntax to shorten the Scheduler type annotations
-Scheduler = Union[ComposerScheduler, PyTorchScheduler]
+Scheduler = Union[ComposerScheduler, LRScheduler]
def _raise_missing_argument_exception(arg_name: str):
@@ -90,7 +93,7 @@ def _scale_max_duration_by_ssr(
def _get_default_scheduler_frequency(schedulers: Optional[Union[Scheduler, Sequence[Scheduler]]]):
- has_pytorch_scheduler = any(isinstance(scheduler, PyTorchScheduler) for scheduler in ensure_tuple(schedulers))
+ has_pytorch_scheduler = any(isinstance(scheduler, LRScheduler) for scheduler in ensure_tuple(schedulers))
if has_pytorch_scheduler:
log.info(('Stepping schedulers every epoch, as a PyTorch scheduler was provided. '
'The trainer cannot automatically convert the parameters (e.g. step_size, T_max) of the '
@@ -124,14 +127,19 @@ def _compile_schedulers(
schedulers: Optional[Union[Scheduler, Sequence[Scheduler]]],
state: State,
scale_schedule_ratio: float,
-) -> List[PyTorchScheduler]:
+) -> List[LRScheduler]:
compiled_schedulers = []
for scheduler in ensure_tuple(schedulers):
- if isinstance(scheduler, PyTorchScheduler):
+ if isinstance(scheduler, LRScheduler):
scale_pytorch_scheduler(scheduler, scale_schedule_ratio)
compiled_schedulers.append(scheduler)
- else: # it's a composer scheduler
- compiled_schedulers.append(compile_composer_scheduler(scheduler, state, scale_schedule_ratio))
+ # It's a composer scheduler
+ else:
+ compiled_schedulers.append(compile_composer_scheduler(
+ scheduler,
+ state,
+ scale_schedule_ratio,
+ ))
return compiled_schedulers
@@ -148,8 +156,7 @@ def _set_evaluator_interval_and_subset_num_batches(
if evaluator.eval_interval is None:
evaluator.eval_interval = eval_interval
eval_dataloader = evaluator.dataloader.dataloader
- if isinstance(eval_dataloader, collections.abc.Sized) and (evaluator.subset_num_batches is None or
- evaluator.subset_num_batches == -1):
+ if isinstance(eval_dataloader, collections.abc.Sized) and evaluator.subset_num_batches == -1:
try:
dataloader_len = len(eval_dataloader)
except TypeError:
@@ -451,7 +458,7 @@ class Trainer:
If ``None``, will be set to ``DecoupledSGDW(model.parameters(), lr=0.1)``. (default: ``None``)
.. seealso:: :mod:`composer.optim` for the different optimizers built into Composer.
- schedulers (PyTorchScheduler | ComposerScheduler | Sequence[PyTorchScheduler | ComposerScheduler], optional):
+ schedulers (LRScheduler | ComposerScheduler | Sequence[LRScheduler | ComposerScheduler], optional):
The learning rate schedulers. If ``[]`` or ``None``, the learning rate will be constant.
(default: ``None``).
@@ -695,6 +702,27 @@ class Trainer:
state. This parameter has no effect if ``save_folder`` is ``None``. (default: ``False``)
.. seealso:: :class:`~.CheckpointSaver`
+ save_ignore_keys (List[str] | (Dict) -> None, optional): A list of paths for the ``state_dict`` of the checkpoint,
+ which, when provided, will be ignored from the state_dict before a checkpoint is saved. Each path is a list
+ of strings specifying the keys to index into ``state_dict`` joined together with `/` as a separator (as PyTorch
+ uses `.` in parameter names). If a prefix is provided, all children are also ignored (see Example 2).
+ See :mod:`composer.core.state` for the structure of state_dict.
+
+ Example 1: ``save_ignore_keys = ["state/model/layer1.weights", "state/model/layer1.bias"]`` would ignore
+ layer 1 weights and bias.
+
+ Example 2: ``save_ignore_keys = ["state/model/*"]`` would ignore the entire model, which would have the same
+ effect as the previous example if there was only 1 layer.
+
+ Example 3: ``save_ignore_keys = ["state/model/layer*.weights"]`` would ignore all weights in the model.
+
+ Example 4: ``save_ignore_keys = ["state/rank_zero_seed", "rng"]`` would reset all randomness when
+ saving the checkpoint.
+
+ If a callable, it should take one argument which is the state_dict. The callable is free to arbitrarily modify
+ the state_dict before it is loaded.
+
+ (default: ``None``)
save_num_checkpoints_to_keep (int, optional): The number of checkpoints to keep locally. The oldest checkpoints
are removed first. Set to ``-1`` to keep all checkpoints locally. (default: ``-1``)
@@ -824,8 +852,8 @@ def __init__(
# Optimizers and Scheduling
optimizers: Optional[torch.optim.Optimizer] = None,
- schedulers: Optional[Union[ComposerScheduler, PyTorchScheduler, Sequence[Union[ComposerScheduler,
- PyTorchScheduler]]]] = None,
+ schedulers: Optional[Union[ComposerScheduler, LRScheduler, Sequence[Union[ComposerScheduler,
+ LRScheduler]]]] = None,
scale_schedule_ratio: float = 1.0,
step_schedulers_every_batch: Optional[bool] = None,
@@ -861,6 +889,7 @@ def __init__(
save_overwrite: bool = False,
save_interval: Union[str, int, Time, Callable[[State, Event], bool]] = '1ep',
save_weights_only: bool = False,
+ save_ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]] = None,
save_num_checkpoints_to_keep: int = -1,
save_metrics: bool = False,
@@ -922,26 +951,22 @@ def __init__(
_validate_precision(precision, device)
# check if provided model is compiled or not
- is_torch_2_0 = using_torch_2()
is_model_compiled = False
- if is_torch_2_0:
- from torch._dynamo import OptimizedModule
- if isinstance(model, OptimizedModule):
- log.warning(f'Provided `model` is already compiled with `torch.compile`. Ignoring ' +
- f'parameter `compile_config` if provided. If you would like `Trainer` ' +
- f'to takes care of model compilation, provide a not-compiled model and ' +
- f'`compile_config` parameter.')
- # The `torch.compile` function returns an object of type `torch._dynamo.OptimizedModule`
- # which wraps the original `nn.Module` object and later patches its forward method to
- # optimized `self.forward` method.
- is_model_compiled = True
- compiled_model = model._orig_mod
- if not isinstance(compiled_model, ComposerModel):
- raise ValueError(f'Provided `model` must be a subclass of ComposerModel. ' +
- f'Instead found as type `{type(compiled_model)}`')
- compiled_model.forward = model.dynamo_ctx(
- compiled_model.forward) # pyright: ignore [reportGeneralTypeIssues]
- model = compiled_model
+ if isinstance(model, OptimizedModule):
+ log.warning(f'Provided `model` is already compiled with `torch.compile`. Ignoring ' +
+ f'parameter `compile_config` if provided. If you would like `Trainer` ' +
+ f'to takes care of model compilation, provide a not-compiled model and ' +
+ f'`compile_config` parameter.')
+ # The `torch.compile` function returns an object of type `torch._dynamo.OptimizedModule`
+ # which wraps the original `nn.Module` object and later patches its forward method to
+ # optimized `self.forward` method.
+ is_model_compiled = True
+ compiled_model = model._orig_mod
+ if not isinstance(compiled_model, ComposerModel):
+ raise ValueError(f'Provided `model` must be a subclass of ComposerModel. ' +
+ f'Instead found as type `{type(compiled_model)}`')
+ compiled_model.forward = model.dynamo_ctx(compiled_model.forward)
+ model = compiled_model
# Microbatching
auto_microbatching = _is_auto_microbatching(device_train_microbatch_size, device=device)
@@ -1047,6 +1072,15 @@ def __init__(
loggers.append(remote_ud)
self.state.profiler.bind_to_state(self.state)
+ # MemorySnapshot, OOMObserver
+ for cb in self.state.callbacks:
+ if isinstance(cb, MemorySnapshot) or isinstance(cb, OOMObserver):
+ if cb.remote_file_name:
+ remote_ud = maybe_create_remote_uploader_downloader_from_uri(uri=cb.remote_file_name,
+ loggers=loggers)
+ if remote_ud is not None:
+ loggers.append(remote_ud)
+
if progress_bar and log_to_console:
warnings.warn(
'Setting both `progress_bar` and `log_to_console` both to True is not recommended and will'
@@ -1085,6 +1119,11 @@ def __init__(
mosaicml_logger = MosaicMLLogger()
loggers.append(mosaicml_logger)
+ # Remote Uploader Downloader
+ # Keep the ``RemoteUploaderDownloader`` below client-provided loggers so the loggers init callbacks run before
+ # the ``RemoteUploaderDownloader`` init. This is necessary to use an ``MLFlowObjectStore`` to log objects to a
+ # run managed by an ``MLFlowLogger``, as the ``MLFlowObjectStore`` relies on the ``MLFlowLogger`` to initialize
+ # the active MLFlow run.
if save_folder is not None:
remote_ud = maybe_create_remote_uploader_downloader_from_uri(save_folder, loggers)
if remote_ud is not None:
@@ -1144,6 +1183,7 @@ def __init__(
latest_remote_file_name=latest_remote_file_name,
overwrite=save_overwrite,
weights_only=save_weights_only,
+ ignore_keys=save_ignore_keys,
save_interval=save_interval,
num_checkpoints_to_keep=save_num_checkpoints_to_keep,
)
@@ -1153,14 +1193,41 @@ def __init__(
self.engine = Engine(state=self.state, logger=self.logger, algorithm_passes=algorithm_passes)
# Set the logger
- self.state.model.logger = self.logger
+ self.state.model.logger = self.logger # pyright: ignore[reportGeneralTypeIssues]
# Run Event.INIT
self.engine.run_event(Event.INIT)
+ # If the experiment is being tracked with an `MLFlowLogger`, then MLFlow experiment and run are available
+ # after Event.INIT.
+ if save_folder is not None:
+ mlflow_logger = None
+ for destination in self.logger.destinations:
+ if isinstance(destination, MLFlowLogger):
+ mlflow_logger = destination
+ break
+
+ if mlflow_logger is not None:
+ mlflow_experiment_id = mlflow_logger._experiment_id
+ mlflow_run_id = mlflow_logger._run_id
+
+ # The save folder and related paths/filenames may contain format placeholders for the MLFlow IDs, so
+ # populate them now.
+ mlflow_format_kwargs = {
+ MLFLOW_EXPERIMENT_ID_FORMAT_KEY: mlflow_experiment_id,
+ MLFLOW_RUN_ID_FORMAT_KEY: mlflow_run_id
+ }
+
+ save_folder = partial_format(save_folder, **mlflow_format_kwargs)
+ if latest_remote_file_name is not None:
+ latest_remote_file_name = partial_format(latest_remote_file_name, **mlflow_format_kwargs)
+
# Log hparams.
if self.auto_log_hparams:
- self.local_hparams = extract_hparams(locals())
+ locs = locals()
+ if 'cb' in locs:
+ del locs['cb']
+ self.local_hparams = extract_hparams(locs)
self.logger.log_hyperparameters(self.local_hparams)
# Log composer version
@@ -1272,10 +1339,6 @@ def __init__(
self.state.scaler = ClosureGradScaler() if self._use_closures() else GradScaler()
if self.state.fsdp_config is not None:
- if version.parse(torch.__version__) < version.parse('1.13.0'):
- raise RuntimeError('To use FSDP with Composer, you must use torch>=1.13.0.')
- from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
-
# This state should never be reached, but we raise a ValueError just in case
if self._use_closures() and self.state.precision == Precision.AMP_FP16:
raise ValueError(f'Using closures and precision {self.state.precision} is not supported'
@@ -1296,7 +1359,8 @@ def __init__(
# FSDP wrap if not using monolith checkpoint on rank 0 only
if self.state.fsdp_config is not None and fsdp_auto_wrap and not self.state.load_fsdp_monolith_rank0_only:
- prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)
+ with reproducibility.seed_context(self.state.rank_zero_seed):
+ prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)
# Configure Deepspeed
if self.state.deepspeed_config is not None:
@@ -1317,9 +1381,11 @@ def __init__(
self.state.deepspeed_config = _parse_deepspeed_config(self.state.deepspeed_config, state=self.state)
optimizer = ensure_tuple(self.state.optimizers)[0]
log.debug('Initializing deepspeed')
- (self.state.model, self.state.optimizers, _, _) = deepspeed.initialize(config=self.state.deepspeed_config,
- model=self.state.model,
- optimizer=optimizer)
+ (self.state.model, self.state.optimizers, _, _) = deepspeed.initialize(
+ config=self.state.deepspeed_config,
+ model=self.state.model,
+ optimizer=optimizer,
+ )
# Since the DeepSpeed ZeRO optimizer does not inherit torch.optim.Optimizer, the schedulers must be
# compiled and bound BEFORE DeepSpeed initialization. However, this is OK, as the the DeepSpeed Zero
# optimizer uses the same underlying parameter groups as the original optimizer. See
@@ -1333,6 +1399,8 @@ def __init__(
if 'optimizers' in self.state.serialized_attributes:
self.state.serialized_attributes.remove('optimizers')
+ self.engine.run_event(Event.BEFORE_LOAD)
+
# Load Checkpoint
self._rng_state = None
# If autoresume is enabled, first check for existing checkpoints to load
@@ -1347,8 +1415,6 @@ def __init__(
'latest existing checkpoint in `save_folder`. ')
if save_latest_filename is None:
error_message += 'The `save_latest_filename` must be specified so autoresume knows where to load checkpoints from. '
- if run_name is None:
- error_message += 'The `run_name` must be specified when using autoresume so Event.INIT is run with the correct run name. '
if error_message != '':
raise ValueError(error_message)
assert save_folder is not None
@@ -1443,14 +1509,15 @@ def __init__(
# FSDP wrap if model is not yet wrapped and FSDP is enabled. This can happen if
# load_fsdp_monolith_rank0_only=True but no checkpoint was loaded.
if not self.state.fsdp_enabled and self.state.fsdp_config is not None and self.state.fsdp_auto_wrap and self.state.load_fsdp_monolith_rank0_only:
- prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)
+ with reproducibility.seed_context(self.state.rank_zero_seed):
+ prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)
self.engine.run_event(Event.AFTER_LOAD)
# reseed here. This helps with a couple of issues:
- # 1. rng state may change at Event.INIT/Event.AFTER_LOAD. For example, if an algorithm
- # creates a new module and module parameters are initialized randomly, rng state will
- # change. This reseeding nullifies such effects.
+ # 1. rng state may change at Event.INIT/Event.BEFORE_LOAD/Event.AFTER_LOAD. For example,
+ # if an algorithm creates a new module and module parameters are initialized randomly, rng
+ # state will change. This reseeding nullifies such effects.
# 2. While resuming from a checkpoint, we want to spin dataloader and bring it back to the
# same state as at the time of the checkpoint. Therefore, spinning needs to start from the
# same rng state as in the original run.
@@ -1463,9 +1530,8 @@ def __init__(
# The model would need to be torch.compile()'d after being wrapped in a distributed strategy
# to take advantage of any graph breaks.
- if is_torch_2_0 and not is_model_compiled and compile_config is not None:
- compiled_model = torch.compile( # pyright: ignore [reportGeneralTypeIssues]
- self.state.model, **compile_config)
+ if not is_model_compiled and compile_config is not None:
+ compiled_model = torch.compile(self.state.model, **compile_config)
self.state.model = compiled_model._orig_mod
self.state.model.forward = compiled_model.dynamo_ctx(self.state.model.forward)
is_model_compiled = True
@@ -1473,10 +1539,6 @@ def __init__(
# debugging purpose and for unit test.
if self.auto_log_hparams:
self.local_hparams['is_model_compiled'] = is_model_compiled
- elif not is_torch_2_0 and compile_config is not None:
- raise ValueError(f'`torch.compile` is supported for PyTorch 2.0 or higher.' +
- f'Either update your PyTorch version or disable parameter by providing ' +
- f'`compile_config` to `None`.')
@property
def saved_checkpoints(self) -> List[str]:
@@ -1625,8 +1687,8 @@ def fit(
reset_time: bool = False,
# Schedulers
- schedulers: Optional[Union[ComposerScheduler, PyTorchScheduler, Sequence[Union[ComposerScheduler,
- PyTorchScheduler]]]] = None,
+ schedulers: Optional[Union[ComposerScheduler, LRScheduler, Sequence[Union[ComposerScheduler,
+ LRScheduler]]]] = None,
scale_schedule_ratio: float = 1.0,
step_schedulers_every_batch: Optional[bool] = None,
@@ -1739,7 +1801,7 @@ def fit(
If ``reset_time`` is True, then :attr:`.State.max_duration` will be set to this parameter.
optimizers (torch.optim.Optimizer | Sequence[torch.optim.Optimizer], optional): See :class:`.Trainer`.
- schedulers (PyTorchScheduler | ComposerScheduler | Sequence[PyTorchScheduler | ComposerScheduler], optional): See :class:`.Trainer`.
+ schedulers (LRScheduler | ComposerScheduler | Sequence[LRScheduler | ComposerScheduler], optional): See :class:`.Trainer`.
scale_schedule_ratio (float, optional): See :class:`.Trainer`.
step_schedulers_every_batch (bool, optional): See :class:`.Trainer`.
eval_dataloader (Iterable | DataSpec | Evaluator | Sequence[Evaluator], optional): See :class:`.Trainer`.
@@ -1784,6 +1846,7 @@ def fit(
if self.state.max_duration is None:
_raise_missing_argument_exception('max_duration')
+ assert self.state.max_duration is not None
if self.state.dataloader_len is None and self.state.max_duration.unit == TimeUnit.EPOCH:
raise ValueError(
@@ -1932,6 +1995,7 @@ def _compute_and_log_metrics(self, dataloader_label: str, metrics: Dict[str, Met
for metric_name, metric in metrics.items():
assert isinstance(metric, Metric)
if dataloader_label == 'train':
+ assert self.state.train_metrics is not None
self.state.train_metrics[metric_name] = metric
self.state.train_metric_values[metric_name] = computed_metrics[metric_name]
else:
@@ -2022,6 +2086,7 @@ def _train_loop(self) -> None:
# asserted to be not None when Trainer.fit() is called
raise RuntimeError('max_duration must be specified when initializing the Trainer')
+ log.debug('Starting training loop')
while self.state.timestamp < self.state.max_duration:
if int(self.state.timestamp.batch_in_epoch) == 0:
self.engine.run_event(Event.EPOCH_START)
@@ -2070,7 +2135,7 @@ def _train_loop(self) -> None:
self.state.scaler.update()
# total_loss_dict can be None if gradient scaling failed
- if total_loss_dict is not None:
+ if total_loss_dict is not None: # pyright: ignore[reportUnnecessaryComparison]
map_collection(total_loss_dict, dist.all_reduce)
total_loss_dict = {
k: loss.cpu().item() / dist.get_world_size() for k, loss in total_loss_dict.items()
@@ -2099,7 +2164,7 @@ def _train_loop(self) -> None:
for scheduler in self.state.schedulers:
scheduler.step()
- if self.state.train_metrics is not None:
+ if self.state.train_metrics is not None: # pyright: ignore[reportUnnecessaryComparison]
self._compute_and_log_metrics(
dataloader_label='train',
metrics=self.state.train_metrics,
@@ -2134,7 +2199,7 @@ def _train_loop(self) -> None:
# This happens if the "break" did not trigger above, or if it
# did (e.g. duration specified in samples/batches/tokens), but it is still
# the end of the dataloader (i.e. next(dataloader) would raise StopIteration)
- if self.state.train_metrics is not None:
+ if self.state.train_metrics is not None: # pyright: ignore[reportUnnecessaryComparison]
self.state.train_metrics = self._ensure_metrics_device_and_dtype(self.state.train_metrics)
self._compute_and_log_metrics(
dataloader_label='train',
@@ -2230,7 +2295,7 @@ def _train_batch(self, use_grad_scaling: bool) -> Dict[str, torch.Tensor]:
while True:
# Reset train_metrics on every batch
# Placing reset here ensures that if auto grad accum catches an OOM, incomplete metric state is cleared
- if self.state.train_metrics is not None:
+ if self.state.train_metrics is not None: # pyright: ignore[reportUnnecessaryComparison]
for metric in self.state.train_metrics.values():
metric.reset()
@@ -2454,6 +2519,7 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int,
else:
microbatch_loss = self.state.device.tensor_to_device(torch.zeros(size=(1,)))
for loss in ensure_tuple(self.state.loss):
+ assert isinstance(loss, torch.Tensor)
microbatch_loss.add_(loss.mean())
# Copy the loss if it is a dictionary
@@ -2471,7 +2537,7 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int,
microbatch_loss_dict[k] = loss.detach().clone().mean() * (microbatch_num_samples / current_batch_size)
if use_grad_scaling:
- microbatch_loss = cast(torch.Tensor, self.state.scaler.scale(microbatch_loss))
+ microbatch_loss = cast(torch.Tensor, self.state.scaler.scale(microbatch_loss)) # type: ignore
if self.state.deepspeed_enabled:
self.state.deepspeed_model.backward(microbatch_loss)
@@ -2483,7 +2549,8 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int,
self.engine.run_event(Event.AFTER_BACKWARD)
# Use microbatch outputs to update training metrics
- if self.state.train_metrics is not None and len(self.state.train_metrics) != 0:
+ if (self.state.train_metrics is not None and # pyright: ignore[reportUnnecessaryComparison]
+ len(self.state.train_metrics) != 0):
self.state.train_metrics = self._ensure_metrics_device_and_dtype(self.state.train_metrics)
self._eval_train_metrics(device_batch)
@@ -2592,8 +2659,7 @@ def predict_batch_end(self, state: State, logger: Logger) -> None:
self.state.batch = self.state.device.batch_to_device(self.state.batch)
# Perform any device transforms
- if data_spec.device_transforms is not None:
- self.state.batch = data_spec.device_transforms(self.state.batch)
+ self.state.batch = data_spec.device_transforms(self.state.batch)
# Count the batch size and num tokens before any events run
rank_num_samples = data_spec.get_num_samples_in_batch(self.state.batch)
@@ -2856,8 +2922,7 @@ def _eval_loop(
for self.state.batch in self._iter_dataloader(TrainerMode.EVAL):
self.state.batch = self.state.device.batch_to_device(self.state.batch)
- if data_spec.device_transforms is not None:
- self.state.batch = data_spec.device_transforms(self.state.batch)
+ self.state.batch = data_spec.device_transforms(self.state.batch)
# Count the batch size and num tokens before any events run
rank_num_samples = data_spec.get_num_samples_in_batch(self.state.batch)
@@ -2927,6 +2992,13 @@ def _eval_loop(
outputs[k] = v.cpu()
else:
outputs[k] = v
+ elif isinstance(self.state.outputs, Sequence):
+ outputs = []
+ for v in self.state.outputs:
+ if isinstance(v, torch.Tensor):
+ outputs.append(v.cpu())
+ else:
+ outputs.append(v)
else:
outputs = self.state.outputs.cpu()
else:
@@ -3070,7 +3142,7 @@ def _use_closures(self) -> bool:
if self.state.precision != Precision.AMP_FP16:
return True
- if self.state.optimizers is None:
+ if not hasattr(self.state, 'optimizers'):
raise RuntimeError('state.optimizers must be set before `_use_closures` can be determined')
return all(
diff --git a/composer/utils/__init__.py b/composer/utils/__init__.py
index 30930250d9..67ed33cdd3 100644
--- a/composer/utils/__init__.py
+++ b/composer/utils/__init__.py
@@ -6,7 +6,8 @@
from composer.utils.auto_log_hparams import (convert_flat_dict_to_nested_dict, convert_nested_dict_to_flat_dict,
extract_hparams)
from composer.utils.batch_helpers import batch_get, batch_set
-from composer.utils.checkpoint import PartialFilePath, load_checkpoint, safe_torch_load, save_checkpoint
+from composer.utils.checkpoint import (PartialFilePath, get_save_filename, load_checkpoint, safe_torch_load,
+ save_checkpoint)
from composer.utils.collect_env import (configure_excepthook, disable_env_report, enable_env_report,
get_composer_env_dict, print_env)
from composer.utils.device import get_device, is_hpu_installed, is_tpu_installed
@@ -20,7 +21,7 @@
from composer.utils.inference import ExportFormat, Transform, export_for_inference, export_with_logger, quantize_dynamic
from composer.utils.iter_helpers import IteratorFileStream, ensure_tuple, map_collection
from composer.utils.misc import (create_interval_scheduler, get_free_tcp_port, is_model_deepspeed, is_model_fsdp,
- is_notebook, model_eval_mode, using_torch_2)
+ is_notebook, model_eval_mode, partial_format)
from composer.utils.object_store import (GCSObjectStore, LibcloudObjectStore, MLFlowObjectStore, ObjectStore,
ObjectStoreTransientError, OCIObjectStore, S3ObjectStore, SFTPObjectStore,
UCObjectStore)
@@ -42,11 +43,12 @@
'LibcloudObjectStore',
'S3ObjectStore',
'SFTPObjectStore',
+ 'MLFlowObjectStore',
'OCIObjectStore',
'GCSObjectStore',
'UCObjectStore',
- 'MLFlowObjectStore',
'MissingConditionalImportError',
+ 'get_save_filename',
'import_object',
'is_model_deepspeed',
'is_model_fsdp',
@@ -84,10 +86,10 @@
'extract_hparams',
'convert_nested_dict_to_flat_dict',
'convert_flat_dict_to_nested_dict',
- 'using_torch_2',
'create_interval_scheduler',
'EvalClient',
'LambdaEvalClient',
'LocalEvalClient',
'MosaicMLLambdaEvalClient',
+ 'partial_format',
]
diff --git a/composer/utils/batch_helpers.py b/composer/utils/batch_helpers.py
index c897fccd5c..5778776dd2 100644
--- a/composer/utils/batch_helpers.py
+++ b/composer/utils/batch_helpers.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
"""Helpers to get items and set items in a batch."""
+from __future__ import annotations
from operator import attrgetter, itemgetter
from typing import Any, Callable, Sequence, Union, cast
@@ -9,7 +10,7 @@
__all__ = ['batch_get', 'batch_set']
-def batch_get(batch: Any, key: Union[str, int, Callable, Any]):
+def batch_get(batch: Any, key: Union[str, int, tuple[Callable, Callable], Callable, Any]):
"""Indexes into the batch given the key.
>>> from composer.utils.batch_helpers import batch_get
@@ -27,7 +28,7 @@ def batch_get(batch: Any, key: Union[str, int, Callable, Any]):
Can be any abritrary type that user creates, but we assume some sort of
sequence (list, tuple, tensor, array), mapping (dictionary),
or attribute store (object with data members, namedtuple).
- key (str | int | Tuple[Callable, Callable] | Any, optional): A key to index into the batch or a
+ key (str | int | Tuple[Callable, Callable] | Callable | Any, optional): A key to index into the batch or a
user-specified function to do the extracting. A pair of callables is also
supported for cases where a get and set function pair are both passed
(like in Algorithms). The getter is assumed to be the first of the pair.
@@ -58,7 +59,7 @@ def batch_get(batch: Any, key: Union[str, int, Callable, Any]):
return attrgetter(*key)(batch)
-def batch_set(batch: Any, key: Union[str, int, Callable, Any], value: Any) -> Any:
+def batch_set(batch: Any, key: Union[str, int, tuple[Callable, Callable], Callable, Any], value: Any) -> Any:
"""Indexes into the batch given the key and sets the element at that index to value.
This is not an in-place operation for batches of type tuple as tuples are not mutable.
@@ -83,7 +84,7 @@ def batch_set(batch: Any, key: Union[str, int, Callable, Any], value: Any) -> An
Can be any abritrary type that user creates, but we assume some sort of
sequence (list, tuple, tensor, array), mapping (dictionary),
or attribute store (object with data members, namedtuple).
- key (str | int | Tuple[Callable, Callable] | Any, optional): A key to index into the batch or a user-specified function
+ key (str | int | Tuple[Callable, Callable] | Callable | Any, optional): A key to index into the batch or a user-specified function
to do the setting. A pair of callables is also supported for cases where a get
and set function pair are both passed (like in Algorithms). The setter is
assumed to be the second of the pair.
diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py
index f7610b6daf..a50a2db27d 100644
--- a/composer/utils/checkpoint.py
+++ b/composer/utils/checkpoint.py
@@ -16,15 +16,20 @@
import warnings
from importlib import import_module
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
import torch
from packaging import version
+from torch.distributed import checkpoint as dist_cp
+from torch.distributed._tensor import DeviceMesh
+from torch.distributed.checkpoint.metadata import Metadata
+from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
+from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner
from composer.utils import dist, reproducibility
from composer.utils.file_helpers import (FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, format_name_with_dist,
format_name_with_dist_and_time, get_file, is_tar)
-from composer.utils.misc import is_model_deepspeed, using_torch_2
+from composer.utils.misc import is_model_deepspeed, partial_format
from composer.utils.object_store import ObjectStore
if TYPE_CHECKING:
@@ -33,7 +38,7 @@
log = logging.getLogger(__name__)
-__all__ = ['load_checkpoint', 'save_checkpoint', 'download_checkpoint']
+__all__ = ['get_save_filename', 'load_checkpoint', 'save_checkpoint', 'download_checkpoint']
_COMPOSER_STATES_FILENAME = 'composer_states.pt'
_DEEPSPEED_TAG = 'deepspeed' # always tag with the same, deterministic name. We'll rename the tarball to the appropriate name.
@@ -131,6 +136,178 @@ def _get_write_mode(name: str) -> str:
raise ValueError(f'{name} does not end with a valid tarfile extension.')
+def _get_num_ranks_that_saved_rng(metadata: Metadata):
+ rng_inds = []
+ for field_name, field_value in metadata.planner_data.items():
+ if 'rng' in field_name:
+ _, rng_rank_index, _ = field_value
+ rng_inds.append(rng_rank_index)
+ rng_inds = set(rng_inds)
+ return len(rng_inds)
+
+
+class FileSystemReaderWithValidation(dist_cp.FileSystemReader):
+ """FileSystemReader that validates checkpoint files prior to reading."""
+
+ def __init__(self, path: str):
+ if _get_checkpoint_validation_function() is None:
+ log.info('No checkpoint validation function found when loading sharded checkpoints.')
+ super().__init__(path)
+
+ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
+ """Reads data file.
+
+ Raises:
+ ValueError if the data file is invalid.
+ """
+ validated_checkpoint_paths = set()
+ for read_item in plan.items:
+ data_path = self.path / self.storage_data[read_item.storage_index].relative_path
+ if data_path in validated_checkpoint_paths:
+ continue
+ _ensure_valid_checkpoint(data_path)
+ validated_checkpoint_paths.add(data_path)
+ return super().read_data(plan, planner)
+
+ def read_metadata(self) -> Metadata:
+ """Reads metadata file.
+
+ Raises:
+ ValueError if the metadata file is invalid.
+ """
+ metadata_file_path = self.path / '.metadata'
+ _ensure_valid_checkpoint(metadata_file_path)
+ return super().read_metadata()
+
+
+# A subclass of FileSystemReaderWithValidation that downloads files from the object store before reading them from the local filesystem.
+class DistCPObjectStoreReader(FileSystemReaderWithValidation):
+
+ def __init__(self, source_path: str, destination_path: str, object_store: Union[ObjectStore, LoggerDestination],
+ device_mesh: Optional[DeviceMesh]):
+ self.source_path = source_path
+ self.destination_path = destination_path
+ self.object_store = object_store
+ self.device_mesh = device_mesh
+
+ # Download metadata file.
+ Path(self.destination_path).mkdir(parents=True, exist_ok=True)
+ metadata_destination = os.path.join(self.destination_path, '.metadata')
+ if dist.get_local_rank() == 0:
+ metadata_path = str(Path(source_path) / Path('.metadata'))
+ if isinstance(object_store, ObjectStore):
+ object_store.download_object(
+ object_name=metadata_path,
+ filename=metadata_destination,
+ )
+ else:
+ object_store.download_file(
+ remote_file_name=metadata_path,
+ destination=metadata_destination,
+ )
+ dist.barrier()
+
+ # FileSystemReader takes in a root directory in its constructor, which is the dir where
+ # the metadata is expected to be stored. Also, this is parent directory for any shard file relative paths
+ # specified in the metadata file.
+ super().__init__(destination_path)
+
+ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
+ # Download files if not using HSDP or if on first replica with HSDP enabled
+ first_replica = self.device_mesh is None or self.device_mesh.ndim == 1 or (
+ self.device_mesh.ndim >= 2 and self.device_mesh.get_local_rank(mesh_dim=0) == 0)
+
+ # 1. Collect the relative paths to download for all ranks for deduplication
+ relative_file_paths = set()
+ for plan_item in plan.items:
+ relative_file_paths.add(self.storage_data[plan_item.storage_index].relative_path)
+ all_file_paths = dist.all_gather_object(relative_file_paths)
+
+ # 2. Download to the destination all files this rank needs if on first replica
+ if first_replica:
+ log.debug(f'Rank {dist.get_global_rank()} starting to download files.')
+
+ # Get the lowest rank in the current node
+ local_rank_0 = dist.get_global_rank() - dist.get_local_rank()
+
+ for plan_item in plan.items:
+ relative_file_path = self.storage_data[plan_item.storage_index].relative_path
+ # Check if the file is scheduled to be downloaded by a lower rank on the same node
+ # i.e. if rank 0 and rank 1 on the same node have the same the same required file,
+ # only rank 0 should download it and not rank 1.
+ is_downloaded = any(
+ relative_file_path in all_file_paths[i] for i in range(local_rank_0, dist.get_global_rank()))
+
+ # Download the shard file to the relative path it's associated to and save that relative path
+ # to the root directory specified to the FileSystem reader constructor.
+ file_destination = str(Path(self.destination_path) / Path(relative_file_path))
+
+ # The file could have already been downloaded as different plan items can point to same file.
+ if not is_downloaded and not os.path.exists(file_destination):
+ log.debug(f'Downloading {relative_file_path} to {file_destination}.')
+ object_name = str(Path(self.source_path) / Path(relative_file_path))
+ if isinstance(self.object_store, ObjectStore):
+ self.object_store.download_object(
+ object_name=object_name,
+ filename=file_destination,
+ )
+ else:
+ self.object_store.download_file(
+ remote_file_name=object_name,
+ destination=file_destination,
+ )
+ log.debug(f'Finished downloading {relative_file_path} to {file_destination}.')
+
+ # 3. Wait for all ranks to finish.
+ log.debug(f'Rank {dist.get_global_rank()} finished downloading all files.')
+ dist.barrier()
+ log.debug('Done waiting for all ranks to finish downloading files.')
+
+ # 4. Broadcast files to all other replicas if HSDP
+ if self.device_mesh is not None and self.device_mesh.ndim == 2:
+ # Broadcast file to all replicas
+ replicate_process_group = self.device_mesh.get_group(0)
+ shard_size = self.device_mesh.size(1)
+ rank_in_first_replica = dist.get_global_rank() % shard_size
+ sender = dist.get_global_rank() == rank_in_first_replica
+ receiver = dist.get_global_rank() != rank_in_first_replica
+
+ # Send list of files to all ranks
+ file_list = [sorted(os.listdir(self.destination_path))]
+ dist.broadcast_object_list(file_list, src=rank_in_first_replica, group=replicate_process_group)
+ file_list = file_list[0]
+ log.debug(f'List of files to broadcast: {file_list}')
+
+ # Send each file to the appropriate rank
+ for file_name in file_list:
+ if 'metadata' in file_name: # All ranks already have the metadata file
+ continue
+ if dist.get_local_rank() == 0: # Only 1 rank per node needs to transfer file
+ full_path = os.path.join(self.destination_path, file_name)
+ log.debug(f'Transferring {full_path=}')
+ file_object = [None]
+ if sender:
+ with open(full_path, 'rb') as f:
+ file_object = [{'content': f.read()}]
+ dist.broadcast_object_list(file_object,
+ src=dist.get_global_rank() % shard_size,
+ group=replicate_process_group)
+ received_file_object = file_object[0]
+ assert received_file_object is not None
+ if receiver and not os.path.exists(full_path):
+ with open(full_path, 'wb') as f:
+ f.write(received_file_object['content'])
+
+ log.debug(f'Rank {dist.get_global_rank()} finished transferring files to all ranks.')
+ dist.barrier()
+ log.debug(
+ f'Done waiting for all ranks to finish transferring files. Local checkpoint files: {os.listdir(self.destination_path)}'
+ )
+
+ # 5. Piggyback off of the FileSystemReader to read all the files now that they are downloaded.
+ return super().read_data(plan, planner)
+
+
class PartialFilePath:
def __init__(self, filename: str, folder: Optional[str] = None):
@@ -170,7 +347,16 @@ def is_checkpoint_legacy_sharded(object_store: Optional[ObjectStore], source_pat
try:
with tempfile.TemporaryDirectory() as temp_dir:
metadata_destination = os.path.join(str(temp_dir), '.metadata')
- object_store.download_object(object_name=metadata_path, filename=metadata_destination)
+ if isinstance(object_store, ObjectStore):
+ object_store.download_object(
+ object_name=metadata_path,
+ filename=metadata_destination,
+ )
+ else:
+ object_store.download_file(
+ remote_file_name=metadata_path,
+ destination=metadata_destination,
+ )
return False
except FileNotFoundError:
return True
@@ -276,6 +462,7 @@ def load_checkpoint(
Optional[list[dict[str, Any]]]: The RNG state dicts, indexed by global rank, if
:attr:`load_weights_only` is not None. Otherwise, None.
"""
+ path = partial_format(path, run_name=state.run_name)
using_legacy_sharded = False
if state.fsdp_elastic_sharded_enabled:
assert object_store is None or isinstance(
@@ -367,12 +554,7 @@ def load_sharded_checkpoint(
ignore_keys: Optional[Union[list[str], Callable[[dict], None]]] = None,
exclude_algorithms: Optional[list[str]] = None,
algorithm_passes: Optional[list[AlgorithmPass]] = None,
-) -> list[dict]:
-
- if not using_torch_2():
- raise ValueError(
- f'Sharded checkpoint loading requires torch version >= 2.0.0. You have torch version {torch.__version__}')
-
+) -> Union[list[dict], None]:
using_multinode = dist.get_world_size() != dist.get_local_world_size()
if not version.parse(torch.__version__) >= version.parse('2.0.1') and using_multinode:
raise ValueError(
@@ -381,93 +563,6 @@ def load_sharded_checkpoint(
if state.fsdp_config is None:
raise ValueError('Loading a sharded checkpoint requires passing an FSDP config to Trainer.')
- load_planner = state.fsdp_config['load_planner']
- _validate_load_planner(load_planner)
-
- from torch.distributed import checkpoint as dist_cp
- from torch.distributed.checkpoint.metadata import Metadata
- from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
- from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner
-
- # This function is used so we can figure out which ranks need to load saved rngs and which can just make their own.
- def _get_num_ranks_that_saved_rng(metadata: Metadata):
- rng_inds = []
- for field_name, field_value in metadata.planner_data.items():
- if 'rng' in field_name:
- _, rng_rank_index, _ = field_value
- rng_inds.append(rng_rank_index)
- rng_inds = set(rng_inds)
- return len(rng_inds)
-
- class FileSystemReaderWithValidation(dist_cp.FileSystemReader):
- """FileSystemReader that validates checkpoint files prior to reading."""
-
- def __init__(self, path: str):
- if _get_checkpoint_validation_function() is None:
- log.info('No checkpoint validation function found when loading sharded checkpoints.')
- super().__init__(path)
-
- def read_data(self, plan: LoadPlan, planner: LoadPlanner):
- """Reads data file.
-
- Raises:
- ValueError if the data file is invalid.
- """
- for read_item in plan.items:
- data_path = self.path / self.storage_data[read_item.storage_index].relative_path
- _ensure_valid_checkpoint(data_path)
- return super().read_data(plan, planner)
-
- def read_metadata(self) -> Metadata:
- """Reads metadata file.
-
- Raises:
- ValueError if the metadata file is invalid.
- """
- metadata_file_path = self.path / '.metadata'
- _ensure_valid_checkpoint(metadata_file_path)
- return super().read_metadata()
-
- # A subclass of FileSystemReaderWithValidation that downloads files from the object store before reading them from the local filesystem.
- class DistCPObjectStoreReader(FileSystemReaderWithValidation):
-
- def __init__(self, source_path: str, destination_path: str, object_store):
- self.source_path = source_path
- self.destination_path = destination_path
- self.object_store = object_store
-
- # Download metadata file.
- Path(self.destination_path).mkdir(parents=True, exist_ok=True)
- metadata_destination = os.path.join(self.destination_path, '.metadata')
- if dist.get_local_rank() == 0:
- object_store.download_object(object_name=str(Path(source_path) / Path('.metadata')),
- filename=metadata_destination)
- dist.barrier()
-
- # FileSystemReader takes in a root directory in its constructor, which is the dir where
- # the metadata is expected to be stored. Also, this is parent directory for any shard file relative paths
- # specified in the metadata file.
- super().__init__(destination_path)
-
- def read_data(self, plan: LoadPlan, planner: LoadPlanner):
- # 1. Download to the destination all files that this rank is responsible for.
- for plan_item in plan.items:
- # Each plan item has a storage index which points to the relative path of the shard file at save time.
- relative_file_path = self.storage_data[plan_item.storage_index].relative_path
- # Download the shard file to the relative path it's associated to and save that relative path
- # to the root directory specified to the FileSystem reader constructor.
- file_destination = str(Path(self.destination_path) / Path(relative_file_path))
- # The file could have already been downloaded as diffeent plan items can point to same file.
- if not os.path.exists(file_destination):
- self.object_store.download_object(object_name=str(
- Path(self.source_path) / Path(relative_file_path)),
- filename=file_destination)
-
- # 2. Wait for all ranks to finish.
- dist.barrier()
-
- # 3. Piggyback off of the FileSystemReader to read all the files now that they are downloaded.
- return super().read_data(plan, planner)
# Check to make sure source_path is a directory.
if object_store is None:
@@ -486,39 +581,58 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
# Get the tempfile made on local rank 0.
local_rank0_index = dist.get_global_rank() - dist.get_local_rank()
rank0_download_tempdir = str(dist.all_gather_object(temp_download_dir)[local_rank0_index])
- storage_reader = DistCPObjectStoreReader(source_path=source_path,
- destination_path=str(
- Path(rank0_download_tempdir) / Path('checkpoints')),
- object_store=object_store)
+ storage_reader = DistCPObjectStoreReader(
+ source_path=source_path,
+ destination_path=str(Path(rank0_download_tempdir) / Path('checkpoints')),
+ object_store=object_store,
+ device_mesh=state.fsdp_device_mesh,
+ )
else:
storage_reader = FileSystemReaderWithValidation(source_path)
# We need no_grad because we overwrite tensor values with set_() when we do elastic loading and we don't want the set_ op recorded in the computation graph.
with torch.no_grad():
# 1. Load model and metadata first
- model_state_dict = None
if load_weights_only:
- model_state_dict = {'state': {'model': state.get_model_state_dict()}}
+ state_dict: Dict[str, Any] = {'state': {'model': state.get_model_state_dict()}}
else:
cur_state_dict = state.state_dict()
- cur_state_dict.pop('optimizers')
- model_state_dict = {'state': cur_state_dict}
+ # For older versions of torch, we load optimizer separately.
+ if version.parse(torch.__version__) < version.parse('2.2.9'):
+ cur_state_dict.pop('optimizers')
+ num_rng_ranks = _get_num_ranks_that_saved_rng(storage_reader.read_metadata())
+ state_dict: Dict[str, Any] = {
+ 'state': cur_state_dict,
+ 'rng': reproducibility.get_rng_state()[:num_rng_ranks],
+ }
if ignore_keys:
# Filter provided list of key paths
if not callable(ignore_keys):
ignore_keys = glob_filter(ignore_keys)
# Call function to modify state_dict
- ignore_keys(model_state_dict)
-
- dist_cp.load_state_dict(
- state_dict=model_state_dict,
- storage_reader=storage_reader,
- planner=load_planner,
- )
+ ignore_keys(state_dict)
+ # Ensure state exists
+ state_dict['state'] = state_dict.get('state', {})
+
+ if version.parse(torch.__version__) > version.parse('2.2.9'):
+ dist_cp.load( # type: ignore
+ state_dict=state_dict,
+ storage_reader=storage_reader,
+ planner=state.fsdp_config['load_planner'],
+ no_dist=(not dist.is_initialized()),
+ )
+ else:
+ dist_cp.load_state_dict(
+ state_dict=state_dict,
+ storage_reader=storage_reader,
+ planner=state.fsdp_config['load_planner'],
+ no_dist=(not dist.is_initialized()),
+ )
+ log.info(f'Loaded state dict')
state.load_state_dict(
- model_state_dict['state'],
+ state_dict['state'],
logger,
strict=strict_model_weights,
exclude_algorithms=exclude_algorithms,
@@ -526,32 +640,14 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
)
# 2. Optionally load optimizer
- if not load_weights_only:
+ # if we are using later than 2.2.9 then optimizer will already be loaded
+ if version.parse(torch.__version__) < version.parse('2.2.9') and not load_weights_only:
optim_state = load_sharded_optimizer_state_dict(model_state_dict=state.state_dict()['model'],
optimizer_key='optimizers',
storage_reader=storage_reader)
- state.load_optim_state(optim_state)
-
- # 3. Optionally load RNG
- rng_state_dicts = reproducibility.get_rng_state()
- if not load_weights_only:
- # If we are resuming on more ranks than were used at save time we only want to load in rngs for those ranks
- num_ranks_that_saved_rng = _get_num_ranks_that_saved_rng(storage_reader.read_metadata())
- rng_state_dicts_load = {}
- rng_state_dicts_load['rng'] = rng_state_dicts[:num_ranks_that_saved_rng] if len(
- rng_state_dicts) > num_ranks_that_saved_rng else rng_state_dicts
- dist_cp.load_state_dict(
- state_dict=rng_state_dicts_load,
- storage_reader=storage_reader,
- planner=load_planner,
- )
- # We also want to append newly generated rng states for the ranks that don't have an rng state to load in
- # if we are resuming on more ranks than were used at save time.
- if len(rng_state_dicts) > num_ranks_that_saved_rng:
- rng_state_dicts_load['rng'].extend(rng_state_dicts[num_ranks_that_saved_rng:])
- rng_state_dicts = rng_state_dicts_load['rng']
+ state._legacy_load_optim_state(optim_state)
- return rng_state_dicts
+ return state_dict.get('rng', None)
def _get_local_rank_zero_path(path: Optional[str]) -> str:
@@ -599,7 +695,7 @@ def download_checkpoint(path: str,
checkpoint_is_sharded = fsdp_sharded_state_dict_enabled or deepspeed_sharded_checkpoint
try:
if not checkpoint_is_sharded and dist.get_local_rank() == 0:
- # if the checkpoint is not sharded, then local rank 0 on each node needs to download the
+ # If the checkpoint is not sharded, then local rank 0 on each node needs to download the
# global rank 0 checkpoint
path = _format_path_with_rank_zero(path)
get_file(destination=rank_zero_checkpoint_filepath,
@@ -616,18 +712,18 @@ def download_checkpoint(path: str,
# or could not be downloaded
raise RuntimeError(f'Checkpoint {path} does not exist')
elif checkpoint_is_sharded:
- # if the checkpoint is sharded, then every rank needs to download its own checkpoint
+ # If the checkpoint is sharded, then every rank needs to download its own checkpoint
+ path = _format_path_with_current_rank(path)
try:
get_file(destination=rank_n_checkpoint_filepath,
- path=_format_path_with_current_rank(path),
+ path=path,
object_store=object_store,
progress_bar=progress_bar)
except FileNotFoundError as e:
raise FileNotFoundError(
- (f'Checkpoint {_format_path_with_current_rank(path)} does not exist, '
- f'but is required for sharded checkpointing on rank {dist.get_global_rank()}. '
- 'Please ensure that the checkpoint exists and your load_path was specified as a format string'
- 'with the {rank} argument.')) from e
+ (f'Checkpoint {path} does not exist, but is required for sharded checkpointing '
+ f'on rank {dist.get_global_rank()}. Please ensure that the checkpoint exists '
+ 'and your load_path was specified as a format string with the {rank} argument.')) from e
if extracted_checkpoint_folder is not None:
try:
@@ -677,14 +773,25 @@ def _flatten_keys(obj: Any, paths: list[str], existing_path: str):
def _remove_paths(obj: Union[list, dict[str, Any]], exclude_paths: list[list[str]]):
+ # Build str(key) to key map to undo cast from glob filtering. Despite typing, some state_dict
+ # keys are not strings, so we need to cast them back to their original type.
+ str_key_to_key = {}
+ if isinstance(obj, dict):
+ for key in obj.keys():
+ str_key_to_key[str(key)] = key
+
# First determine the keys which will be recursed on and which will be removed entirely
# Group the `exclude_paths` by the key
keys_to_recurse = {}
keys_to_remove = []
for exclude_path_parts in exclude_paths:
key = exclude_path_parts[0]
+ # Cast list indices to int
if isinstance(obj, list):
key = int(key)
+ # Un-str dict keys if necessary
+ if key in str_key_to_key:
+ key = str_key_to_key[key]
if len(exclude_path_parts) == 1:
keys_to_remove.append(key)
else:
@@ -720,51 +827,17 @@ def filter_func(state_dict: dict) -> None:
f'No parts from loaded checkpoint state_dict were ignored by load_ignore_key {exclude_glob}')
filtered_paths.extend(filtered_paths_from_glob)
filtered_paths = list(set(filtered_paths))
- filtered_paths_str = ', '.join(filtered_paths)
if filtered_paths:
+ filtered_paths_str = ', '.join(filtered_paths)
log.info(f'Ignoring the following paths from the loaded checkpoint state_dict: {filtered_paths_str}')
# Loop through all paths to exclude
- paths_to_remove = [path.split('/') for path in filtered_paths]
+ paths_to_remove = [path.split('/') for path in filtered_paths if len(path) > 0]
_remove_paths(state_dict, paths_to_remove)
return filter_func
-def _validate_save_planner(save_planner: Optional[Any]) -> None:
- """Checks that ``save_planner`` is an instance of a :class:`~torch.distributed.checkpoint.planner.SavePlanner`.
-
- TODO(GRT-2456): Remove validation once we deprecate torch 1.13 and can use
- type hints.
-
- Raises:
- ValueError: If ``save_planner`` is not a
- :class:`~torch.distributed.checkpoint.planner.SavePlanner`.
- """
- from torch.distributed.checkpoint.planner import SavePlanner
-
- if save_planner is not None and not isinstance(save_planner, SavePlanner):
- raise ValueError((f'save_planner {type(save_planner)} is not a '
- 'torch.distributed.checkpoint.planner.SavePlanner'))
-
-
-def _validate_load_planner(load_planner: Optional[Any]) -> None:
- """Checks that ``load_planner`` is an instance of a :class:`~torch.distributed.checkpoint.planner.LoadPlanner`.
-
- TODO(GRT-2456): Remove validation once we deprecate torch 1.13 and can use
- type hints.
-
- Raises:
- ValueError: If ``load_planner`` is not a
- :class:`~torch.distributed.checkpoint.planner.LoadPlanner`.
- """
- from torch.distributed.checkpoint.planner import LoadPlanner
-
- if load_planner is not None and not isinstance(load_planner, LoadPlanner):
- raise ValueError((f'load_planner {type(load_planner)} is not a '
- 'torch.distributed.checkpoint.planner.LoadPlanner'))
-
-
def safe_torch_load(
composer_states_filepath: Union[Path, str],
map_location: str = 'cpu',
@@ -840,6 +913,8 @@ def _restore_checkpoint(
ignore_keys = glob_filter(ignore_keys)
# Call function to modify state_dict
ignore_keys(state_dict)
+ # Ensure state exists
+ state_dict['state'] = state_dict.get('state', {})
log.debug(f"Loaded checkpoint with keys {state_dict.keys()} and state keys {state_dict['state'].keys()}")
if is_model_deepspeed(state.model):
@@ -876,11 +951,39 @@ def _restore_checkpoint(
return state_dict.get('rng', None)
-def save_checkpoint(
+def get_save_filename(
state: State,
filename: str = 'ep{epoch}-ba{batch}-rank{rank}',
+) -> str:
+ """Gets full filename of save filename.
+
+ Args:
+ state (State): The :class:`~composer.core.State` to load the checkpoint into.
+ filename (filename): The name of the save file.
+
+ Returns:
+ Full filename of save file.
+ """
+ if not state.fsdp_sharded_state_dict_enabled:
+ is_deepspeed = is_model_deepspeed(state.model)
+ return PartialFilePath(filename).format(state, is_deepspeed)
+
+ # Sharded checkpoints get their own little folder.
+ assert state.sharded_ckpt_prefix_dir is not None
+ save_dirpath = Path(Path(filename).parent) / Path(state.sharded_ckpt_prefix_dir)
+ save_dirpath = format_name_with_dist_and_time(str(save_dirpath), state.run_name, state.timestamp)
+ # New name is now Trainer.save_folder / sharded_ckpt_prefix_dir / __{dist.get_global_rank()}_0.distcpâ
+ # e.g. path/to/my/checkpoints/ep1-ba2/__1_0.distcp
+ ckpt_filename = _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME
+ return str(Path(save_dirpath) / Path(ckpt_filename))
+
+
+def _save_checkpoint(
+ state: State,
+ save_filename: str,
*,
weights_only: bool = False,
+ ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]] = None,
) -> Union[str, None]: # noqa: D103
is_deepspeed = is_model_deepspeed(state.model)
@@ -900,38 +1003,36 @@ def save_checkpoint(
'rng': reproducibility.get_rng_state(),
}
- log.debug('State dict created.')
+ if ignore_keys:
+ # Filter provided list of key paths
+ if not callable(ignore_keys):
+ ignore_keys = glob_filter(ignore_keys)
+ # Call function to modify state_dict
+ ignore_keys(state_dict)
+ # Ensure state exists
+ state_dict['state'] = state_dict.get('state', {})
- # Sharded checkpoints get their own little folder.
if state.fsdp_sharded_state_dict_enabled:
- # To load optimizer states with torch 2.0, the optimizer state must be at the top
+ # To load optimizer states with 2.0 <= torch < 2.2.9 , the optimizer state must be at the top
# level of the state dict because the load_sharded_optimizer_state_dict function
# requires a top level state dict key for the optimizer.
# See https://github.com/pytorch/pytorch/blob/v2.0.1/torch/distributed/checkpoint/optimizer.py#L271
# for more info.
- if using_torch_2():
+ if version.parse(torch.__version__) < version.parse('2.2.9'):
if not weights_only:
state_dict['optimizers'] = state_dict['state'].pop('optimizers')
-
- # Specify save directory path and save_f
- assert state.sharded_ckpt_prefix_dir is not None
- save_dirpath = Path(Path(filename).parent) / Path(state.sharded_ckpt_prefix_dir)
- save_dirpath = format_name_with_dist_and_time(str(save_dirpath), state.run_name, state.timestamp)
- # New name is now Trainer.save_folder / sharded_ckpt_prefix_dir / __{dist.get_global_rank()}_0.distcpâ if torch > 2
- # else Trainer.save_folder / sharded_ckpt_prefix_dir / ba{batch}_rank{dist.get_global_rank()}.ptâ
- # e.g. path/to/my/checkpoints/ep1-ba2/__1_0.distcp if torch >2 else its path/to/my/checkpoints/ep1-ba2/b2-rank1.pt
- ckpt_filename = _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME if using_torch_2() else format_name_with_dist_and_time(
- Path(filename).name, state.run_name, state.timestamp)
- save_filename = str(Path(save_dirpath) / Path(ckpt_filename))
- else:
- save_filename = PartialFilePath(filename).format(state, is_deepspeed)
+ log.debug('State dict created.')
dirname = os.path.dirname(save_filename)
if dirname:
os.makedirs(dirname, exist_ok=True)
+ # Only some ranks are meant to save checkpoint and produce a file
+ expect_file = False
+
# All ranks save for deepspeed
if is_deepspeed:
+ expect_file = True
log.debug('Saving deepspeed checkpoints to %s...', save_filename)
if dist.get_global_rank() == 0:
with open(save_filename, 'wb') as f:
@@ -941,24 +1042,43 @@ def save_checkpoint(
_save_deepspeed_model(state.deepspeed_model, save_filename)
- # Sharded checkpointing for torch >=2.0 uses the torch.distributed.checkpoint module.
+ # Sharded checkpointing
elif state.fsdp_elastic_sharded_enabled:
if state.fsdp_config is None:
raise ValueError('Saving a sharded checkpoint requires passing an FSDP config to Trainer.')
- save_planner = state.fsdp_config['save_planner']
- _validate_save_planner(save_planner)
-
- import torch.distributed.checkpoint as dist_cp
- log.debug('Saving sharded checkpoints to %s...', save_filename)
- dist_cp.save_state_dict(
- state_dict=state_dict,
- storage_writer=dist_cp.FileSystemWriter(dirname),
- planner=save_planner,
- )
+ log.debug(f'Saving sharded checkpoints to {save_filename}...')
+ process_group = None
+ device_mesh = state.fsdp_device_mesh
+ if device_mesh is not None and device_mesh.ndim == 2:
+ # If hybrid shard, only rank in first replica saves
+ expect_file = device_mesh.get_local_rank(mesh_dim=0) == 0
+ if expect_file:
+ process_group = device_mesh.get_group(1) # Shard process_group for first replica
+ log.debug(f'Saving on global_rank={dist.get_global_rank()}, {expect_file=}')
+ else:
+ expect_file = True
+
+ if expect_file:
+ if version.parse(torch.__version__) > version.parse('2.2.9'):
+ dist_cp.save( # type: ignore
+ state_dict=state_dict,
+ storage_writer=dist_cp.FileSystemWriter(dirname),
+ planner=state.fsdp_config['save_planner'],
+ process_group=process_group,
+ )
+ else:
+ dist_cp.save_state_dict(
+ state_dict=state_dict,
+ storage_writer=dist_cp.FileSystemWriter(dirname),
+ planner=state.fsdp_config['save_planner'],
+ process_group=process_group,
+ )
+ log.debug('Finished pytorch save state dict')
# Only rank 0 saves the state_dict unless you are using sharded checkpointing with torch <2.0
elif dist.get_global_rank() == 0 or state.fsdp_sharded_state_dict_enabled:
+ expect_file = True
log_msg = f'Saving sharded checkpoints to {save_filename}...' if state.fsdp_sharded_state_dict_enabled else f'Saving monolithic checkpoint to {save_filename}'
with open(save_filename, 'wb') as f:
log.debug(log_msg)
@@ -974,7 +1094,7 @@ def save_checkpoint(
dist.barrier() # ensure all ranks saved their files
- if dist.get_global_rank() == 0 or is_deepspeed or state.fsdp_sharded_state_dict_enabled:
+ if expect_file:
assert os.path.exists(save_filename), 'Expected file to have been saved.'
return save_filename
else:
@@ -1014,6 +1134,17 @@ def _save_deepspeed_model(model, filename: str):
tar.add(tmpdir, arcname='')
+def save_checkpoint(
+ state: State,
+ filename: str = 'ep{epoch}-ba{batch}-rank{rank}',
+ *,
+ weights_only: bool = False,
+ ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]] = None,
+) -> Union[str, None]: # noqa: D103
+ save_filename = get_save_filename(state, filename)
+ return _save_checkpoint(state, save_filename, weights_only=weights_only, ignore_keys=ignore_keys)
+
+
save_checkpoint.__doc__ = f"""Checkpoint the training ``state``.
Args:
diff --git a/composer/utils/collect_env.py b/composer/utils/collect_env.py
index 2926c54a6f..02e74af8f9 100644
--- a/composer/utils/collect_env.py
+++ b/composer/utils/collect_env.py
@@ -378,7 +378,6 @@ def print_env(file: Optional[TextIO] = None) -> None:
[pip3] torch-optimizer==0.1.0
[pip3] torchmetrics==0.7.3
[pip3] torchvision==0.10.1+cu111
- [pip3] vit-pytorch==0.27.0
[conda] Could not collect
diff --git a/composer/utils/dist.py b/composer/utils/dist.py
index 1b59bff1d4..65edb5e80c 100644
--- a/composer/utils/dist.py
+++ b/composer/utils/dist.py
@@ -253,15 +253,19 @@ def get_node_rank() -> int:
return _get_distributed_config_var(env_var='NODE_RANK', default=0, human_name='node rank')
-def barrier() -> None:
+def barrier(group=None) -> None:
"""Synchronizes all processes.
This function blocks until all processes reach this function.
.. seealso:: :func:`torch.distributed.barrier`
+
+ Args:
+ group (ProcessGroup, optional): The process group to work on. If ``None``,
+ the default process group will be used. Default is ``None``.
"""
if dist.is_available() and dist.is_initialized():
- dist.barrier()
+ dist.barrier(group=group)
return
world_size = get_world_size()
if world_size == 1:
@@ -276,6 +280,7 @@ def barrier() -> None:
def all_reduce(
tensor: torch.Tensor,
reduce_operation: str = 'SUM',
+ group=None,
) -> None:
"""Reduce a ``tensor`` by applying the ``reduce_operation``.
@@ -289,6 +294,8 @@ def all_reduce(
op (optional): One of the values from
``torch.distributed.ReduceOp``
enum. Specifies an operation used for element-wise reductions.
+ group (ProcessGroup, optional): The process group to work on. If ``None``,
+ the default process group will be used. Default is ``None``.
Args:
tensor (torch.Tensor): Tensor to reduce. The function operates in-place.
reduce_operation (str, optional): The reduction operation (default: ``SUM``).
@@ -307,7 +314,7 @@ def all_reduce(
"""
if dist.is_available() and dist.is_initialized():
reduce_op = getattr(dist.ReduceOp, reduce_operation.upper())
- dist.all_reduce(tensor, op=reduce_op)
+ dist.all_reduce(tensor, op=reduce_op, group=group)
return
world_size = get_world_size()
if world_size == 1:
@@ -319,7 +326,7 @@ def all_reduce(
'`composer.utils.dist.initialize_dist` has been called first.')
-def broadcast(tensor: torch.Tensor, src: int) -> None:
+def broadcast(tensor: torch.Tensor, src: int, group=None) -> None:
"""Broadcasts the tensor to the whole group.
``tensor`` must have the same number of elements in all processes participating in the collective.
@@ -329,9 +336,11 @@ def broadcast(tensor: torch.Tensor, src: int) -> None:
tensor (torch.Tensor): Data to be sent if ``src`` is the rank of current process,
and tensor to be used to save received data otherwise.
src (int): Source rank
+ group (ProcessGroup, optional): The process group to work on. If ``None``,
+ the default process group will be used. Default is ``None``.
"""
if dist.is_available() and dist.is_initialized():
- dist.broadcast(tensor, src)
+ dist.broadcast(tensor, src=src, group=group)
return
world_size = get_world_size()
if world_size == 1:
@@ -343,7 +352,7 @@ def broadcast(tensor: torch.Tensor, src: int) -> None:
'`composer.utils.dist.initialize_dist` has been called first.')
-def broadcast_object_list(object_list: List[Any], src: int = 0) -> None:
+def broadcast_object_list(object_list: List[Any], src: int = 0, group=None) -> None:
"""Broadcasts picklable objects in ``object_list`` to the whole group.
Similar to :func:`broadcast`, but Python objects can be passed in.
@@ -356,12 +365,14 @@ def broadcast_object_list(object_list: List[Any], src: int = 0) -> None:
Each object must be picklable. Only objects on the ``src`` rank will be broadcast,
but each rank must provide lists of equal sizes.
src (int, optional): Source rank (default: ``0``)
+ group (ProcessGroup, optional): The process group to work on. If ``None``,
+ the default process group will be used. Default is ``None``.
Returns:
None: ``object_list`` will be modified in-place and set to values of ``object_list`` from the ``src`` rank.
"""
if dist.is_available() and dist.is_initialized():
- dist.broadcast_object_list(object_list, src)
+ dist.broadcast_object_list(object_list, src=src, group=group)
# torch.distributed will replace the None's in obj_gather_list with the gathered objects on rank 0
# or will just be None on non-rank-0
return
@@ -375,20 +386,22 @@ def broadcast_object_list(object_list: List[Any], src: int = 0) -> None:
'`composer.utils.dist.initialize_dist` has been called first.')
-def all_gather(tensor: torch.Tensor) -> Sequence[torch.Tensor]:
+def all_gather(tensor: torch.Tensor, group=None) -> Sequence[torch.Tensor]:
"""Collects a :class:`~torch.Tensor` from each rank.
.. seealso:: :func:`torch.distributed.all_gather`
Args:
tensor (torch.Tensor): Tensor from each rank to be gathered.
+ group (ProcessGroup, optional): The process group to work on. If ``None``,
+ the default process group will be used. Default is ``None``.
Returns:
Sequence[Tensor]: A sequence of tensors indexed by rank.
"""
if dist.is_available() and dist.is_initialized():
obj_gather_list = [torch.zeros_like(tensor) for _ in range(get_world_size())]
- dist.all_gather(obj_gather_list, tensor)
+ dist.all_gather(obj_gather_list, tensor, group=group)
return obj_gather_list
world_size = get_world_size()
if world_size == 1:
@@ -400,13 +413,15 @@ def all_gather(tensor: torch.Tensor) -> Sequence[torch.Tensor]:
'`composer.utils.dist.initialize_dist` has been called first.')
-def all_gather_object(obj: TObj) -> List[TObj]:
+def all_gather_object(obj: TObj, group=None) -> List[TObj]:
"""Collect a pickleable object from each rank and return a list of these objects indexed by rank.
.. seealso:: :func:`torch.distributed.all_gather_object`
Args:
obj (TObj): Object to be gathered.
+ group (ProcessGroup, optional): The process group to work on. If ``None``,
+ the default process group will be used. Default is ``None``.
Returns:
List[TObj]: A list of objects indexed by rank.
@@ -414,9 +429,9 @@ def all_gather_object(obj: TObj) -> List[TObj]:
if dist.is_available() and dist.is_initialized():
obj_gather_list = [None for _ in range(get_world_size())]
if is_hpu_installed():
- all_gather_object_list_hpu(obj_gather_list, obj)
+ all_gather_object_list_hpu(obj_gather_list, obj, group=group)
else:
- dist.all_gather_object(obj_gather_list, obj)
+ dist.all_gather_object(obj_gather_list, obj, group=group)
# torch.distributed will replace the None's in obj_gather_list with the gathered objects on rank 0
# or will just be None on non-rank-0
return cast(List[TObj], obj_gather_list)
diff --git a/composer/utils/eval_client/local_eval_client.py b/composer/utils/eval_client/local_eval_client.py
index 357660b284..710a478473 100644
--- a/composer/utils/eval_client/local_eval_client.py
+++ b/composer/utils/eval_client/local_eval_client.py
@@ -38,7 +38,7 @@ def invoke_helper(self, payload: Dict[str, str]) -> bool:
p.start()
p.join(TIMEOUT) # wait for timeout to terminate
p.terminate()
- return bool(ret.value)
+ return bool(ret.value) # pyright: ignore[reportGeneralTypeIssues]
def update_offline_helper(self, code_gen: str, test_input: str, test_output: str, entry_point: str, language: str,
val: multiprocessing.Value): # type: ignore
diff --git a/composer/utils/eval_client/mosaicml_lambda_eval_client.py b/composer/utils/eval_client/mosaicml_lambda_eval_client.py
index fabb6b32be..cc9ea74714 100644
--- a/composer/utils/eval_client/mosaicml_lambda_eval_client.py
+++ b/composer/utils/eval_client/mosaicml_lambda_eval_client.py
@@ -46,7 +46,7 @@ def invoke(self, payload: List[List[List[Dict[str, str]]]]) -> List[List[List[bo
ret_helper = [False] * len(test_cases)
for i in range(self.num_retries):
try:
- ret_helper = mcli.get_code_eval_output(test_cases).data
+ ret_helper = mcli.get_code_eval_output(test_cases).data # pyright: ignore[reportGeneralTypeIssues]
break
except mcli.MAPIException as e:
if e.status >= 500:
diff --git a/composer/utils/file_helpers.py b/composer/utils/file_helpers.py
index a3d421259b..7c75b4633e 100644
--- a/composer/utils/file_helpers.py
+++ b/composer/utils/file_helpers.py
@@ -20,7 +20,10 @@
from composer.utils import dist
from composer.utils.iter_helpers import iterate_with_callback
-from composer.utils.object_store import GCSObjectStore, ObjectStore, OCIObjectStore, S3ObjectStore, UCObjectStore
+from composer.utils.misc import partial_format
+from composer.utils.object_store import (GCSObjectStore, LibcloudObjectStore, MLFlowObjectStore, ObjectStore,
+ OCIObjectStore, S3ObjectStore, UCObjectStore)
+from composer.utils.object_store.mlflow_object_store import MLFLOW_DBFS_PATH_PREFIX
if TYPE_CHECKING:
from composer.core import Timestamp
@@ -29,9 +32,16 @@
log = logging.getLogger(__name__)
__all__ = [
- 'get_file', 'ensure_folder_is_empty', 'ensure_folder_has_no_conflicting_files', 'format_name_with_dist',
- 'format_name_with_dist_and_time', 'is_tar', 'create_symlink_file', 'maybe_create_object_store_from_uri',
- 'maybe_create_remote_uploader_downloader_from_uri', 'parse_uri'
+ 'get_file',
+ 'ensure_folder_is_empty',
+ 'ensure_folder_has_no_conflicting_files',
+ 'format_name_with_dist',
+ 'format_name_with_dist_and_time',
+ 'is_tar',
+ 'create_symlink_file',
+ 'maybe_create_object_store_from_uri',
+ 'maybe_create_remote_uploader_downloader_from_uri',
+ 'parse_uri',
]
@@ -166,7 +176,8 @@ def ensure_folder_has_no_conflicting_files(folder_name: Union[str, pathlib.Path]
def format_name_with_dist(format_str: str, run_name: str, **extra_format_kwargs: object): # noqa: D103
- formatted_str = format_str.format(
+ formatted_str = partial_format(
+ format_str,
run_name=run_name,
**_get_dist_config(strict=False),
**extra_format_kwargs,
@@ -259,7 +270,8 @@ def format_name_with_dist_and_time(
timestamp: Timestamp,
**extra_format_kwargs: object,
): # noqa: D103
- formatted_str = format_str.format(
+ formatted_str = partial_format(
+ format_str,
run_name=run_name,
epoch=int(timestamp.epoch),
batch=int(timestamp.batch),
@@ -314,6 +326,7 @@ def parse_uri(uri: str) -> Tuple[str, str, str]:
Tuple[str, str, str]: A tuple containing the backend (e.g. s3), bucket name, and path.
Backend name will be empty string if the input is a local path
"""
+ uri = uri.replace('AZURE_BLOBS', 'azure') # urlparse does not support _ in scheme
parse_result = urlparse(uri)
backend, net_loc, path = parse_result.scheme, parse_result.netloc, parse_result.path
bucket_name = net_loc if '@' not in net_loc else net_loc.split('@')[0]
@@ -349,10 +362,36 @@ def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
return GCSObjectStore(bucket=bucket_name)
elif backend == 'oci':
return OCIObjectStore(bucket=bucket_name)
+ elif backend == 'azure':
+ return LibcloudObjectStore(
+ provider='AZURE_BLOBS',
+ container=bucket_name,
+ key_environ='AZURE_ACCOUNT_NAME',
+ secret_environ='AZURE_ACCOUNT_ACCESS_KEY',
+ )
elif backend == 'dbfs':
- # validate if the path conforms to the requirements for UC volume paths
- UCObjectStore.validate_path(path)
- return UCObjectStore(path=path)
+ if path.startswith(MLFLOW_DBFS_PATH_PREFIX):
+ store = None
+ if dist.get_global_rank() == 0:
+ store = MLFlowObjectStore(path)
+
+ # The path may have had placeholders, so update it with the experiment/run IDs initialized by the store
+ path = store.get_dbfs_path(path)
+
+ # Broadcast the rank 0 updated path to all ranks for their own object stores
+ path_list = [path]
+ dist.broadcast_object_list(path_list, src=0)
+ path = path_list[0]
+
+ # Create the object store for all other ranks
+ if dist.get_global_rank() != 0:
+ store = MLFlowObjectStore(path)
+
+ return store
+ else:
+ # validate if the path conforms to the requirements for UC volume paths
+ UCObjectStore.validate_path(path)
+ return UCObjectStore(path=path)
else:
raise NotImplementedError(f'There is no implementation for the cloud backend {backend} via URI. Please use '
'one of the supported object stores')
@@ -387,14 +426,21 @@ def maybe_create_remote_uploader_downloader_from_uri(
return None
if backend in ['s3', 'oci', 'gs']:
return RemoteUploaderDownloader(bucket_uri=f'{backend}://{bucket_name}')
-
+ elif backend == 'azure':
+ return RemoteUploaderDownloader(
+ bucket_uri=f'libcloud://{bucket_name}',
+ backend_kwargs={
+ 'provider': 'AZURE_BLOBS',
+ 'container': bucket_name,
+ 'key_environ': 'AZURE_ACCOUNT_NAME',
+ 'secret_environ': 'AZURE_ACCOUNT_ACCESS_KEY',
+ },
+ )
+ elif backend == 'dbfs':
+ return RemoteUploaderDownloader(bucket_uri=uri, backend_kwargs={'path': path})
elif backend == 'wandb':
raise NotImplementedError(f'There is no implementation for WandB via URI. Please use '
'WandBLogger with log_artifacts set to True')
- elif backend == 'dbfs':
- # validate if the path conforms to the requirements for UC volume paths
- UCObjectStore.validate_path(path)
- return RemoteUploaderDownloader(bucket_uri=uri, backend_kwargs={'path': path})
else:
raise NotImplementedError(f'There is no implementation for the cloud backend {backend} via URI. Please use '
'one of the supported RemoteUploaderDownloader object stores')
diff --git a/composer/utils/fx_utils.py b/composer/utils/fx_utils.py
index 2b1ff41b3e..9162b84878 100644
--- a/composer/utils/fx_utils.py
+++ b/composer/utils/fx_utils.py
@@ -234,6 +234,7 @@ def apply_stochastic_residual(gm: GraphModule, drop_rate: float = 0.2) -> Tuple[
f'Input to apply_stochastic_residual should be an instance of GraphModule. Received {type(gm)}')
all_tags, count = _tag_residual_nodes(gm)
split_gm = split_by_tags(gm, all_tags)
+ assert isinstance(split_gm, GraphModule)
for node in split_gm.graph.nodes:
if node.op != 'call_module':
continue
diff --git a/composer/utils/iter_helpers.py b/composer/utils/iter_helpers.py
index 1338dbf872..7c256fd21e 100644
--- a/composer/utils/iter_helpers.py
+++ b/composer/utils/iter_helpers.py
@@ -6,11 +6,14 @@
# All methods signatures must be defined in there.
"""Utilities for iterating over collections."""
+from __future__ import annotations
+
import collections.abc
import io
+from typing import Any
-def map_collection(collection, map_fn):
+def map_collection(collection, map_fn) -> Any:
"""Applies ``map_fn`` on each element in ``collection``.
* If ``collection`` is a tuple or list of elements, ``map_fn`` is applied on each element,
@@ -37,7 +40,7 @@ def map_collection(collection, map_fn):
return map_fn(collection)
-def ensure_tuple(x):
+def ensure_tuple(x) -> tuple[Any, ...]:
"""Converts ``x`` into a tuple.
* If ``x`` is ``None``, then ``tuple()`` is returned.
diff --git a/composer/utils/misc.py b/composer/utils/misc.py
index 76573f8901..e5fa5942ae 100644
--- a/composer/utils/misc.py
+++ b/composer/utils/misc.py
@@ -9,7 +9,6 @@
from typing import TYPE_CHECKING, Callable, Optional, Set, Type, Union
import torch
-from packaging import version
from torch.nn.parallel import DistributedDataParallel
if TYPE_CHECKING:
@@ -52,21 +51,21 @@ def create_interval_scheduler(interval: Union[str, int, 'Time'],
if final_events is None:
final_events = {Event.BATCH_CHECKPOINT, Event.EPOCH_CHECKPOINT}
- interval = Time.from_input(interval, TimeUnit.EPOCH)
- if interval.unit == TimeUnit.EPOCH:
+ time_interval: Time = Time.from_input(interval, TimeUnit.EPOCH)
+ if time_interval.unit == TimeUnit.EPOCH:
interval_event = Event.EPOCH_CHECKPOINT if checkpoint_events else Event.EPOCH_END
- elif interval.unit in {TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE, TimeUnit.DURATION}:
+ elif time_interval.unit in {TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE, TimeUnit.DURATION}:
interval_event = Event.BATCH_CHECKPOINT if checkpoint_events else Event.BATCH_END
else:
raise NotImplementedError(
- f'Unknown interval: {interval.unit}. Must be TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, or TimeUnit.SAMPLE.'
+ f'Unknown interval: {time_interval.unit}. Must be TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, or TimeUnit.SAMPLE.'
)
last_batch_seen = -1
def check_interval(state: State, event: Event):
# `TimeUnit.Duration` value is a float from `[0.0, 1.0)`
- if not interval.unit == TimeUnit.DURATION and int(interval) <= 0:
+ if not time_interval.unit == TimeUnit.DURATION and int(time_interval) <= 0:
return False
nonlocal last_batch_seen # required to use the last_batch_seen from the outer function scope
@@ -81,25 +80,25 @@ def check_interval(state: State, event: Event):
if include_end_of_training and event in final_events and elapsed_duration >= 1.0 and state.timestamp.batch != last_batch_seen:
return True
- if interval.unit in {TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE}:
- previous_count = state.previous_timestamp.get(interval.unit)
- count = state.timestamp.get(interval.unit)
+ if time_interval.unit in {TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE}:
+ previous_count = state.previous_timestamp.get(time_interval.unit)
+ count = state.timestamp.get(time_interval.unit)
# If the eval_interval is a duration, we will track progress in terms of the unit of max_duration
- elif interval.unit == TimeUnit.DURATION:
+ elif time_interval.unit == TimeUnit.DURATION:
assert state.max_duration is not None
previous_count = state.previous_timestamp.get(state.max_duration.unit)
count = state.timestamp.get(state.max_duration.unit)
else:
raise NotImplementedError(
- f'Unknown interval: {interval.unit}. Must be TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, or TimeUnit.SAMPLE.'
+ f'Unknown interval: {time_interval.unit}. Must be TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, or TimeUnit.SAMPLE.'
)
- threshold_passed = math.floor(previous_count / interval.value) != math.floor(count / interval.value)
+ threshold_passed = math.floor(previous_count / time_interval.value) != math.floor(count / time_interval.value)
- if interval.unit != TimeUnit.DURATION and event == interval_event and threshold_passed:
+ if time_interval.unit != TimeUnit.DURATION and event == interval_event and threshold_passed:
last_batch_seen = state.timestamp.batch
return True
- elif interval.unit == TimeUnit.DURATION:
+ elif time_interval.unit == TimeUnit.DURATION:
assert state.max_duration is not None, 'max_duration should not be None'
if state.dataloader_len is None:
raise RuntimeError(
@@ -107,22 +106,22 @@ def check_interval(state: State, event: Event):
if event == interval_event:
if state.max_duration.unit == TimeUnit.EPOCH and int(state.timestamp.batch) % math.ceil(
- state.max_duration.value * float(interval) * state.dataloader_len) == 0:
+ state.max_duration.value * float(time_interval) * state.dataloader_len) == 0:
last_batch_seen = state.timestamp.batch
return True
elif state.max_duration.unit == TimeUnit.BATCH and int(state.timestamp.batch) % math.ceil(
- state.max_duration.value * interval.value) == 0:
+ state.max_duration.value * time_interval.value) == 0:
last_batch_seen = state.timestamp.batch
return True
elif state.max_duration.unit == TimeUnit.SAMPLE:
- samples_per_interval = math.ceil(state.max_duration.value * interval)
+ samples_per_interval = math.ceil(state.max_duration.value * time_interval)
threshold_passed = math.floor(previous_count / samples_per_interval) != math.floor(
count / samples_per_interval)
if threshold_passed:
last_batch_seen = state.timestamp.batch
return True
elif state.max_duration.unit == TimeUnit.TOKEN:
- tokens_per_interval = math.ceil(state.max_duration.value * interval)
+ tokens_per_interval = math.ceil(state.max_duration.value * time_interval)
threshold_passed = math.floor(previous_count / tokens_per_interval) != math.floor(
count / tokens_per_interval)
if threshold_passed:
@@ -208,19 +207,21 @@ def model_eval_mode(model: torch.nn.Module):
model.train(mode=is_training)
-def using_torch_2() -> bool:
- """Check the PyTorch version and compared it with version 2.0.0.
+def partial_format(s, *args, **kwargs) -> str:
+ """Format a string with a partial set of arguments.
- Returns:
- bool: Return True if current version is greater than or equal to 2.0.0 else False
- """
- return version.parse(torch.__version__) >= version.parse('2.0.0')
-
-
-def using_torch_2_0_1() -> bool:
- """Check the PyTorch version and compare it with version 2.0.1.
-
- Returns:
- bool: Return True if current version is greater than or equal to 2.0.1 else False
+ Since `str.format()` raises a `KeyError` if a format key is missing from the arguments, this
+ function allows for a partial set of arguments to be provided. Any missing arguments will be
+ left as-is in the string.
"""
- return version.parse(torch.__version__) >= version.parse('2.0.1')
+ max_iters = 10_000 # Just in case we get stuck in a loop somehow.
+ for _ in range(max_iters):
+ try:
+ return s.format(*args, **kwargs)
+ except IndexError as e: # Missing positional arg
+ args += ('{}',)
+ except KeyError as e: # Missing keyword arg
+ key = e.args[0]
+ kwargs[key] = '{' + key + '}'
+
+ raise RuntimeError(f'Failed to format string {s} after {max_iters} iterations.')
diff --git a/composer/utils/object_store/gcs_object_store.py b/composer/utils/object_store/gcs_object_store.py
index 1240754968..f3550d1ac9 100644
--- a/composer/utils/object_store/gcs_object_store.py
+++ b/composer/utils/object_store/gcs_object_store.py
@@ -76,7 +76,7 @@ def __init__(
self.client = Client.from_service_account_json(service_account_path)
self.use_gcs_sdk = True
try:
- self.bucket = self.client.get_bucket(self.bucket_name, timeout=60.0)
+ self.bucket = self.client.get_bucket(self.bucket_name, timeout=60)
except Exception as e:
_reraise_gcs_errors(self.get_uri(object_name=''), e)
@@ -127,12 +127,15 @@ def get_object_size(self, object_name: str) -> int:
blob_exists = Blob(bucket=self.bucket, name=key).exists(self.client)
if not blob_exists:
raise FileNotFoundError(f'{object_name} not found in {self.bucket_name}')
+ blob = None
try:
key = self.get_key(object_name)
blob = self.bucket.get_blob(key)
except Exception as e:
_reraise_gcs_errors(self.get_uri(object_name), e)
+ if blob is None or blob.size is None:
+ return -1
return blob.size # size in bytes
def upload_object(self,
@@ -223,6 +226,7 @@ def list_objects(self, prefix: Optional[str] = None) -> List[str]:
prefix = ''
prefix = self.get_key(prefix)
+ objects = []
try:
objects = self.bucket.list_blobs(prefix=prefix)
except Exception as e:
diff --git a/composer/utils/object_store/libcloud_object_store.py b/composer/utils/object_store/libcloud_object_store.py
index 9f9d9a7b91..6dec288502 100644
--- a/composer/utils/object_store/libcloud_object_store.py
+++ b/composer/utils/object_store/libcloud_object_store.py
@@ -157,7 +157,9 @@ def _get_object(self, object_name: str):
self._ensure_transient_errors_are_wrapped(e)
def get_object_size(self, object_name: str) -> int:
- return self._get_object(object_name).size
+ obj = self._get_object(object_name)
+ assert obj is not None
+ return obj.size
def download_object(
self,
@@ -178,6 +180,7 @@ def download_object(
tmp_filepath = str(filename) + f'.{uuid.uuid4()}.tmp'
try:
with open(tmp_filepath, 'wb+') as f:
+ assert obj is not None
stream = self._provider.download_object_as_stream(obj, chunk_size=self.chunk_size)
for chunk in iterate_with_callback(stream, obj.size, callback):
f.write(chunk)
diff --git a/composer/utils/object_store/mlflow_object_store.py b/composer/utils/object_store/mlflow_object_store.py
index 15f50bcdb0..a156007dae 100644
--- a/composer/utils/object_store/mlflow_object_store.py
+++ b/composer/utils/object_store/mlflow_object_store.py
@@ -21,8 +21,11 @@
DEFAULT_MLFLOW_EXPERIMENT_NAME = 'mlflow-object-store'
-PLACEHOLDER_EXPERIMENT_ID = '{mlflow_experiment_id}'
-PLACEHOLDER_RUN_ID = '{mlflow_run_id}'
+MLFLOW_EXPERIMENT_ID_FORMAT_KEY = 'mlflow_experiment_id'
+MLFLOW_RUN_ID_FORMAT_KEY = 'mlflow_run_id'
+
+MLFLOW_EXPERIMENT_ID_PLACEHOLDER = '{' + MLFLOW_EXPERIMENT_ID_FORMAT_KEY + '}'
+MLFLOW_RUN_ID_PLACEHOLDER = '{' + MLFLOW_RUN_ID_FORMAT_KEY + '}'
log = logging.getLogger(__name__)
@@ -112,7 +115,10 @@ def __init__(self, path: str, multipart_upload_chunk_size: int = 100 * 1024 * 10
except ImportError as e:
raise MissingConditionalImportError('databricks', conda_package='databricks-sdk>=0.15.0,<1.0') from e
- tracking_uri = os.getenv(mlflow.environment_variables.MLFLOW_TRACKING_URI.name, MLFLOW_DATABRICKS_TRACKING_URI)
+ tracking_uri = os.getenv(
+ mlflow.environment_variables.MLFLOW_TRACKING_URI.name, # pyright: ignore[reportGeneralTypeIssues]
+ MLFLOW_DATABRICKS_TRACKING_URI,
+ )
if tracking_uri != MLFLOW_DATABRICKS_TRACKING_URI:
raise ValueError(
'MLFlowObjectStore currently only supports Databricks-hosted MLflow tracking. '
@@ -129,12 +135,13 @@ def __init__(self, path: str, multipart_upload_chunk_size: int = 100 * 1024 * 10
'to identify different ways to setup credentials.') from e
self._mlflow_client = MlflowClient(tracking_uri)
- mlflow.environment_variables.MLFLOW_MULTIPART_UPLOAD_CHUNK_SIZE.set(multipart_upload_chunk_size)
+ mlflow.environment_variables.MLFLOW_MULTIPART_UPLOAD_CHUNK_SIZE.set( # pyright: ignore[reportGeneralTypeIssues]
+ multipart_upload_chunk_size,)
experiment_id, run_id, _ = MLFlowObjectStore.parse_dbfs_path(path)
- if experiment_id == PLACEHOLDER_EXPERIMENT_ID:
+ if experiment_id == MLFLOW_EXPERIMENT_ID_PLACEHOLDER:
experiment_id = None
- if run_id == PLACEHOLDER_RUN_ID:
+ if run_id == MLFLOW_RUN_ID_PLACEHOLDER:
run_id = None
# Construct the `experiment_id` and `run_id` depending on whether format placeholders were provided.
@@ -158,8 +165,8 @@ def _init_run_info(self, experiment_id: Optional[str], run_id: Optional[str]) ->
log.debug(f'MLFlowObjectStore using active MLflow run {run_id=}')
else:
# If no active run exists, create a new run for the default experiment.
- experiment_name = os.getenv(mlflow.environment_variables.MLFLOW_EXPERIMENT_NAME.name,
- DEFAULT_MLFLOW_EXPERIMENT_NAME)
+ mlflow_env_var_name = mlflow.environment_variables.MLFLOW_EXPERIMENT_NAME.name # pyright: ignore[reportGeneralTypeIssues]
+ experiment_name = os.getenv(mlflow_env_var_name, DEFAULT_MLFLOW_EXPERIMENT_NAME)
experiment = self._mlflow_client.get_experiment_by_name(experiment_name)
if experiment is not None:
@@ -236,10 +243,10 @@ def get_artifact_path(self, object_name: str) -> str:
"""
if object_name.startswith(MLFLOW_DBFS_PATH_PREFIX):
experiment_id, run_id, object_name = self.parse_dbfs_path(object_name)
- if (experiment_id != self.experiment_id and experiment_id != PLACEHOLDER_EXPERIMENT_ID):
+ if (experiment_id != self.experiment_id and experiment_id != MLFLOW_EXPERIMENT_ID_PLACEHOLDER):
raise ValueError(f'Object {object_name} belongs to experiment ID {experiment_id}, '
f'but MLFlowObjectStore is associated with experiment ID {self.experiment_id}.')
- if (run_id != self.run_id and run_id != PLACEHOLDER_RUN_ID):
+ if (run_id != self.run_id and run_id != MLFLOW_RUN_ID_PLACEHOLDER):
raise ValueError(f'Object {object_name} belongs to run ID {run_id}, '
f'but MLFlowObjectStore is associated with run ID {self.run_id}.')
return object_name
diff --git a/composer/utils/object_store/oci_object_store.py b/composer/utils/object_store/oci_object_store.py
index ce3fd5ea2c..d36b13e03b 100644
--- a/composer/utils/object_store/oci_object_store.py
+++ b/composer/utils/object_store/oci_object_store.py
@@ -32,7 +32,7 @@ def _reraise_oci_errors(uri: str, e: Exception):
raise FileNotFoundError(f'Object {uri} not found. {e.message}') from e # type: ignore
if e.code == 'BucketNotFound': # type: ignore
raise ValueError(f'Bucket specified in {uri} not found. {e.message}') from e # type: ignore
- raise e
+ raise FileNotFoundError(f'Object {uri} not found with no error code. {e.message}') from e # type: ignore
# Client errors
if isinstance(e, oci.exceptions.ClientError):
@@ -81,7 +81,7 @@ def __init__(
except Exception as e:
_reraise_oci_errors(self.get_uri(object_name=''), e)
- self.namespace = self.client.get_namespace().data
+ self.namespace = self.client.get_namespace().data # pyright: ignore[reportOptionalMemberAccess]
self.upload_manager = oci.object_storage.UploadManager(self.client)
def get_uri(self, object_name: str) -> str:
@@ -97,10 +97,12 @@ def get_object_size(self, object_name: str) -> int:
except Exception as e:
_reraise_oci_errors(self.get_uri(object_name), e)
- if response.status == 200:
- return int(response.data.headers['Content-Length'])
+ if response.status == 200: # pyright: ignore[reportUnboundVariable, reportOptionalMemberAccess]
+ data = response.data # pyright: ignore[reportUnboundVariable, reportOptionalMemberAccess]
+ return int(data.headers['Content-Length'])
else:
- raise ValueError(f'OCI get_object was not successful with a {response.status} status code.')
+ status = response.status # pyright: ignore[reportUnboundVariable, reportOptionalMemberAccess]
+ raise ValueError(f'OCI get_object was not successful with a {status} status code.')
def upload_object(
self,
@@ -126,7 +128,7 @@ def _download_part(self, object_name, filename, start_byte, end_byte, part_numbe
object_name=object_name,
range=range_header)
with open(tmp_part_path, 'wb') as f:
- f.write(response.data.content)
+ f.write(response.data.content) # pyright: ignore[reportOptionalMemberAccess]
return part_number, tmp_part_path
def download_object(
@@ -146,8 +148,12 @@ def download_object(
os.makedirs(dirname, exist_ok=True)
# Get the size of the object
- head_object_response = self.client.head_object(self.namespace, self.bucket, object_name)
- object_size = head_object_response.headers['content-length']
+ object_size = 0
+ try:
+ head_object_response = self.client.head_object(self.namespace, self.bucket, object_name)
+ object_size = head_object_response.headers['content-length'] # pyright: ignore[reportOptionalMemberAccess]
+ except Exception as e:
+ _reraise_oci_errors(self.get_uri(object_name), e)
# Calculate the part sizes
base_part_size, remainder = divmod(int(object_size), num_parts)
part_sizes = [base_part_size] * num_parts
@@ -156,9 +162,9 @@ def download_object(
part_sizes = [part_size for part_size in part_sizes if part_size > 0]
with TemporaryDirectory(dir=dirname, prefix=f'{str(filename)}') as temp_dir:
+ parts = []
try:
# Download parts in parallel
- parts = []
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = []
start_byte = 0
@@ -198,10 +204,9 @@ def list_objects(self, prefix: Optional[str] = None) -> List[str]:
response_complete = False
try:
while not response_complete:
- response = self.client.list_objects(namespace_name=self.namespace,
- bucket_name=self.bucket,
- prefix=prefix,
- start=next_start_with).data
+ response = self.client.list_objects(
+ namespace_name=self.namespace, bucket_name=self.bucket, prefix=prefix,
+ start=next_start_with).data # pyright: ignore[reportOptionalMemberAccess]
object_names.extend([obj.name for obj in response.objects])
next_start_with = response.next_start_with
if not next_start_with:
diff --git a/composer/utils/object_store/s3_object_store.py b/composer/utils/object_store/s3_object_store.py
index 854d447665..eeae4d28fa 100644
--- a/composer/utils/object_store/s3_object_store.py
+++ b/composer/utils/object_store/s3_object_store.py
@@ -116,6 +116,7 @@ def get_key(self, object_name: str) -> str:
return f'{self.prefix}{object_name}'
def get_object_size(self, object_name: str) -> int:
+ obj = {'ContentLength': -1}
try:
obj = self.client.get_object(Bucket=self.bucket, Key=self.get_key(object_name))
except Exception as e:
diff --git a/composer/utils/object_store/uc_object_store.py b/composer/utils/object_store/uc_object_store.py
index 23e8440354..4fc901212a 100644
--- a/composer/utils/object_store/uc_object_store.py
+++ b/composer/utils/object_store/uc_object_store.py
@@ -24,8 +24,9 @@
def _wrap_errors(uri: str, e: Exception):
from databricks.sdk.core import DatabricksError
+ from databricks.sdk.errors.mapping import NotFound
if isinstance(e, DatabricksError):
- if e.error_code == _NOT_FOUND_ERROR_CODE: # type: ignore
+ if isinstance(e, NotFound) or e.error_code == _NOT_FOUND_ERROR_CODE: # type: ignore
raise FileNotFoundError(f'Object {uri} not found') from e
raise ObjectStoreTransientError from e
@@ -48,6 +49,7 @@ class UCObjectStore(ObjectStore):
"""
_UC_VOLUME_LIST_API_ENDPOINT = '/api/2.0/fs/list'
+ _UC_VOLUME_FILES_API_ENDPOINT = '/api/2.0/fs/files'
def __init__(self, path: str) -> None:
try:
@@ -206,21 +208,21 @@ def get_object_size(self, object_name: str) -> int:
"""
from databricks.sdk.core import DatabricksError
try:
- file_info = self.client.files.get_status(self._get_object_path(object_name))
- if file_info.is_dir:
- raise IsADirectoryError(f'{object_name} is a UC directory, not a file.')
-
- return file_info.file_size # pyright: ignore
+ # Note: The UC team is working on changes to fix the files.get_status API, but it currently
+ # does not work. Once fixed, we will call the files API endpoint. We currently only use this
+ # function in Composer and LLM-foundry to check the UC object's existence.
+ object_path = self._get_object_path(object_name).lstrip('/')
+ path = os.path.join(self._UC_VOLUME_FILES_API_ENDPOINT, object_path)
+ self.client.api_client.do(method='HEAD', path=path, headers={'Source': 'mosaicml/composer'})
+ return 1000000 # Dummy value, as we don't have a way to get the size of the file
except DatabricksError as e:
+ # If the code reaches here, the file was not found
_wrap_errors(self.get_uri(object_name), e)
+ return -1
def list_objects(self, prefix: Optional[str]) -> List[str]:
"""List all objects in the object store with the given prefix.
- .. note::
-
- This function removes the directories from the returned list.
-
Args:
prefix (str): The prefix to search for.
@@ -232,13 +234,35 @@ def list_objects(self, prefix: Optional[str]) -> List[str]:
from databricks.sdk.core import DatabricksError
try:
- data = json.dumps({'path': self._get_object_path(prefix)})
# NOTE: This API is in preview and should not be directly used outside of this instance
- resp = self.client.api_client.do(method='GET',
- path=self._UC_VOLUME_LIST_API_ENDPOINT,
- data=data,
- headers={'Source': 'mosaicml/composer'})
- assert isinstance(resp, dict)
- return [f['path'] for f in resp.get('files', []) if not f['is_dir']]
+ logging.warn('UCObjectStore.list_objects is experimental.')
+
+ # Iteratively get all UC Volume files with `prefix`.
+ stack = [prefix]
+ all_files = []
+
+ while len(stack) > 0:
+ current_path = stack.pop()
+
+ # Note: Databricks SDK handles HTTP errors and retries.
+ # See https://github.com/databricks/databricks-sdk-py/blob/v0.18.0/databricks/sdk/core.py#L125 and
+ # https://github.com/databricks/databricks-sdk-py/blob/v0.18.0/databricks/sdk/retries.py#L33 .
+ resp = self.client.api_client.do(method='GET',
+ path=self._UC_VOLUME_LIST_API_ENDPOINT,
+ data=json.dumps({'path': self._get_object_path(current_path)}),
+ headers={'Source': 'mosaicml/composer'})
+
+ assert isinstance(resp, dict), 'Response is not a dictionary'
+
+ for f in resp.get('files', []):
+ fpath = f['path']
+ if f['is_dir']:
+ stack.append(fpath)
+ else:
+ all_files.append(fpath)
+
+ return all_files
+
except DatabricksError as e:
_wrap_errors(self.get_uri(prefix), e)
+ return []
diff --git a/composer/utils/reproducibility.py b/composer/utils/reproducibility.py
index 0895b530d9..0e2ee0cb5f 100644
--- a/composer/utils/reproducibility.py
+++ b/composer/utils/reproducibility.py
@@ -53,6 +53,7 @@
import textwrap
import time
import warnings
+from contextlib import contextmanager
from typing import Any, Dict, List
import numpy as np
@@ -62,6 +63,7 @@
from composer.utils import dist
__all__ = [
+ 'seed_context',
'configure_deterministic_mode',
'get_random_seed',
'seed_all',
@@ -76,6 +78,15 @@
MAX_SEED = 2**32 - 1
+@contextmanager
+def seed_context(seed: int):
+ """Context manager to store rng_state and reseed for duration of context."""
+ rng_state = get_rng_state()
+ seed_all(seed)
+ yield
+ load_rng_state(rng_state)
+
+
def configure_deterministic_mode():
"""Configure PyTorch deterministic mode.
@@ -218,7 +229,7 @@ def load_rng_state(rng_state_dicts: List[Dict[str, Any]]):
try:
torch.cuda.set_rng_state(rng_state_dict['cuda'])
except RuntimeError as e:
- if 'RNG state is wrong size' in str(e):
+ if 'RNG state is wrong size' in str(e) or 'offset must be a multiple of 4' in str(e):
warnings.warn('The CUDA RNG state could not be loaded from the checkpoint, '
'likely because a different version of torch was used to save the '
'checkpoint. Skipping loading the CUDA RNG state.')
diff --git a/composer/utils/string_enum.py b/composer/utils/string_enum.py
index ba4e534e0c..18a98f9339 100644
--- a/composer/utils/string_enum.py
+++ b/composer/utils/string_enum.py
@@ -64,7 +64,7 @@ class StringEnum(Enum):
warnings.resetwarnings()
"""
- __hash__ = Enum.__hash__
+ __hash__ = Enum.__hash__ # pyright: ignore[reportGeneralTypeIssues]
def __eq__(self, other: object) -> bool:
if isinstance(other, str):
diff --git a/docker/Dockerfile b/docker/Dockerfile
index ea72ebc7b4..e5ae9b9468 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -118,6 +118,7 @@ RUN apt-get update && \
tcl \
libjpeg8-dev \
less \
+ libsnappy-dev \
# For AWS EFA:
autoconf \
autotools-dev \
@@ -269,6 +270,7 @@ RUN if [ -n "$MOFED_VERSION" ] ; then \
rm -rf /tmp/mofed ; \
fi
+
#####################
# Install NVIDIA Apex
#####################
@@ -294,7 +296,7 @@ RUN if [[ -n "$CUDA_VERSION" ]] && [[ -z "${PYTORCH_NIGHTLY_URL}" ]]; then \
RUN if [ -n "$CUDA_VERSION" ] ; then \
pip${PYTHON_VERSION} install --upgrade --no-cache-dir ninja==1.11.1 && \
pip${PYTHON_VERSION} install --upgrade --no-cache-dir --force-reinstall packaging==22.0 && \
- pip${PYTHON_VERSION} install --no-cache-dir flash-attn==1.0.9; \
+ MAX_JOBS=1 pip${PYTHON_VERSION} install --no-cache-dir flash-attn==2.5.0; \
fi
###############
@@ -353,7 +355,8 @@ RUN apt-get update && \
RUN pip install --no-cache-dir --upgrade \
certifi${CERTIFI_VERSION} \
ipython${IPYTHON_VERSION} \
- urllib3${URLLIB3_VERSION}
+ urllib3${URLLIB3_VERSION} \
+ python-snappy
##################################################
# Override NVIDIA mistaken env var for 11.8 images
diff --git a/docker/README.md b/docker/README.md
index 81f7f1fa0d..c617567f2f 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -15,8 +15,8 @@ all dependencies for both NLP and Vision models. They are built on top of the
| Composer Version | CUDA Support | Docker Tag |
|--------------------|----------------|----------------------------------------------------------------|
-| 0.17.2 | Yes | `mosaicml/composer:latest`, `mosaicml/composer:0.17.2` |
-| 0.17.2 | No | `mosaicml/composer:latest_cpu`, `mosaicml/composer:0.17.2_cpu` |
+| 0.19.1 | Yes | `mosaicml/composer:latest`, `mosaicml/composer:0.19.1` |
+| 0.19.1 | No | `mosaicml/composer:latest_cpu`, `mosaicml/composer:0.19.1_cpu` |
**Note**: For a lightweight installation, we recommended using a [MosaicML PyTorch Image](#pytorch-images) and manually
@@ -25,22 +25,23 @@ installing Composer within the image.
## PyTorch Images
The [`mosaicml/pytorch`](https://hub.docker.com/r/mosaicml/pytorch) images contain PyTorch preinstalled, without Composer.
-The base flavor contains PyTorch pre-installed; the vision flavor also includes OpenCV, MM Segmentation, and FFCV dependencies.
To install composer, once inside the image, run `pip install mosaicml`.
| Linux Distro | Flavor | PyTorch Version | CUDA Version | Python Version | Docker Tags |
|----------------|----------|-------------------|---------------------|------------------|------------------------------------------------------------------------------------------|
-| Ubuntu 20.04 | Base | 2.2.0 | 12.1.0 (Infiniband) | 3.10 | `mosaicml/pytorch:2.2.0_cu121-nightly20231213-python3.10-ubuntu20.04` |
-| Ubuntu 20.04 | Base | 2.1.1 | 12.1.0 (Infiniband) | 3.10 | `mosaicml/pytorch:latest`, `mosaicml/pytorch:2.1.1_cu121-python3.10-ubuntu20.04` |
-| Ubuntu 20.04 | Base | 2.1.1 | 12.1.0 (EFA) | 3.10 | `mosaicml/pytorch:latest-aws`, `mosaicml/pytorch:2.1.1_cu121-python3.10-ubuntu20.04-aws` |
-| Ubuntu 20.04 | Base | 2.1.1 | cpu | 3.10 | `mosaicml/pytorch:latest_cpu`, `mosaicml/pytorch:2.1.1_cpu-python3.10-ubuntu20.04` |
+| Ubuntu 20.04 | Base | 2.3.0 | 12.1.0 (Infiniband) | 3.11 | `mosaicml/pytorch:2.3.0_cu121-nightly20240110-python3.11-ubuntu20.04` |
+| Ubuntu 20.04 | Base | 2.3.0 | 12.1.0 (Infiniband) | 3.10 | `mosaicml/pytorch:2.3.0_cu121-nightly20240110-python3.10-ubuntu20.04` |
+| Ubuntu 20.04 | Base | 2.3.0 | 12.1.0 (EFA) | 3.10 | `mosaicml/pytorch:2.3.0_cu121-nightly20240110-python3.10-ubuntu20.04-aws` |
+| Ubuntu 20.04 | Base | 2.2.0 | 12.1.0 (Infiniband) | 3.11 | `mosaicml/pytorch:2.2.0_cu121-python3.11-ubuntu20.04` |
+| Ubuntu 20.04 | Base | 2.2.0 | 12.1.0 (EFA) | 3.11 | `mosaicml/pytorch:2.2.0_cu121-python3.11-ubuntu20.04-aws` |
+| Ubuntu 20.04 | Base | 2.2.0 | cpu | 3.11 | `mosaicml/pytorch:2.2.0_cpu-python3.11-ubuntu20.04` |
+| Ubuntu 20.04 | Base | 2.1.2 | 12.1.0 (Infiniband) | 3.10 | `mosaicml/pytorch:latest`, `mosaicml/pytorch:2.1.2_cu121-python3.10-ubuntu20.04` |
+| Ubuntu 20.04 | Base | 2.1.2 | 12.1.0 (EFA) | 3.10 | `mosaicml/pytorch:latest-aws`, `mosaicml/pytorch:2.1.2_cu121-python3.10-ubuntu20.04-aws` |
+| Ubuntu 20.04 | Base | 2.1.2 | cpu | 3.10 | `mosaicml/pytorch:latest_cpu`, `mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04` |
| Ubuntu 20.04 | Base | 2.0.1 | 11.8.0 (Infiniband) | 3.10 | `mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04` |
| Ubuntu 20.04 | Base | 2.0.1 | 11.8.0 (EFA) | 3.10 | `mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04-aws` |
| Ubuntu 20.04 | Base | 2.0.1 | cpu | 3.10 | `mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04` |
-| Ubuntu 20.04 | Base | 1.13.1 | 11.7.1 (Infiniband) | 3.10 | `mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04` |
-| Ubuntu 20.04 | Base | 1.13.1 | 11.7.1 (EFA) | 3.10 | `mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04-aws` |
-| Ubuntu 20.04 | Base | 1.13.1 | cpu | 3.10 | `mosaicml/pytorch:1.13.1_cpu-python3.10-ubuntu20.04` |
**Note**: The `mosaicml/pytorch:latest`, `mosaicml/pytorch:latest_cpu`, and `mosaicml/pytorch:latest-aws`
diff --git a/docker/build_matrix.yaml b/docker/build_matrix.yaml
index 98bac4504b..700bd4c010 100644
--- a/docker/build_matrix.yaml
+++ b/docker/build_matrix.yaml
@@ -2,7 +2,75 @@
- AWS_OFI_NCCL_VERSION: ''
BASE_IMAGE: nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04
CUDA_VERSION: 12.1.0
- IMAGE_NAME: torch-2-1-1-cu121
+ IMAGE_NAME: torch-2-2-0-cu121
+ MOFED_VERSION: 5.5-1.0.3.2
+ NVIDIA_REQUIRE_CUDA_OVERRIDE: cuda>=12.1 brand=tesla,driver>=450,driver<451 brand=tesla,driver>=470,driver<471
+ brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471
+ brand=geforce,driver>=470,driver<471 brand=geforcertx,driver>=470,driver<471 brand=quadro,driver>=470,driver<471
+ brand=quadrortx,driver>=470,driver<471 brand=titan,driver>=470,driver<471 brand=titanrtx,driver>=470,driver<471
+ brand=tesla,driver>=510,driver<511 brand=unknown,driver>=510,driver<511 brand=nvidia,driver>=510,driver<511
+ brand=nvidiartx,driver>=510,driver<511 brand=geforce,driver>=510,driver<511 brand=geforcertx,driver>=510,driver<511
+ brand=quadro,driver>=510,driver<511 brand=quadrortx,driver>=510,driver<511 brand=titan,driver>=510,driver<511
+ brand=titanrtx,driver>=510,driver<511 brand=tesla,driver>=515,driver<516 brand=unknown,driver>=515,driver<516
+ brand=nvidia,driver>=515,driver<516 brand=nvidiartx,driver>=515,driver<516 brand=geforce,driver>=515,driver<516
+ brand=geforcertx,driver>=515,driver<516 brand=quadro,driver>=515,driver<516 brand=quadrortx,driver>=515,driver<516
+ brand=titan,driver>=515,driver<516 brand=titanrtx,driver>=515,driver<516 brand=tesla,driver>=525,driver<526
+ brand=unknown,driver>=525,driver<526 brand=nvidia,driver>=525,driver<526 brand=nvidiartx,driver>=525,driver<526
+ brand=geforce,driver>=525,driver<526 brand=geforcertx,driver>=525,driver<526 brand=quadro,driver>=525,driver<526
+ brand=quadrortx,driver>=525,driver<526 brand=titan,driver>=525,driver<526 brand=titanrtx,driver>=525,driver<526
+ PYTHON_VERSION: '3.11'
+ PYTORCH_NIGHTLY_URL: ''
+ PYTORCH_NIGHTLY_VERSION: ''
+ PYTORCH_VERSION: 2.2.0
+ TAGS:
+ - mosaicml/pytorch:2.2.0_cu121-python3.11-ubuntu20.04
+ TARGET: pytorch_stage
+ TORCHVISION_VERSION: 0.17.0
+- AWS_OFI_NCCL_VERSION: v1.7.4-aws
+ BASE_IMAGE: nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04
+ CUDA_VERSION: 12.1.0
+ IMAGE_NAME: torch-2-2-0-cu121-aws
+ MOFED_VERSION: ''
+ NVIDIA_REQUIRE_CUDA_OVERRIDE: cuda>=12.1 brand=tesla,driver>=450,driver<451 brand=tesla,driver>=470,driver<471
+ brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471
+ brand=geforce,driver>=470,driver<471 brand=geforcertx,driver>=470,driver<471 brand=quadro,driver>=470,driver<471
+ brand=quadrortx,driver>=470,driver<471 brand=titan,driver>=470,driver<471 brand=titanrtx,driver>=470,driver<471
+ brand=tesla,driver>=510,driver<511 brand=unknown,driver>=510,driver<511 brand=nvidia,driver>=510,driver<511
+ brand=nvidiartx,driver>=510,driver<511 brand=geforce,driver>=510,driver<511 brand=geforcertx,driver>=510,driver<511
+ brand=quadro,driver>=510,driver<511 brand=quadrortx,driver>=510,driver<511 brand=titan,driver>=510,driver<511
+ brand=titanrtx,driver>=510,driver<511 brand=tesla,driver>=515,driver<516 brand=unknown,driver>=515,driver<516
+ brand=nvidia,driver>=515,driver<516 brand=nvidiartx,driver>=515,driver<516 brand=geforce,driver>=515,driver<516
+ brand=geforcertx,driver>=515,driver<516 brand=quadro,driver>=515,driver<516 brand=quadrortx,driver>=515,driver<516
+ brand=titan,driver>=515,driver<516 brand=titanrtx,driver>=515,driver<516 brand=tesla,driver>=525,driver<526
+ brand=unknown,driver>=525,driver<526 brand=nvidia,driver>=525,driver<526 brand=nvidiartx,driver>=525,driver<526
+ brand=geforce,driver>=525,driver<526 brand=geforcertx,driver>=525,driver<526 brand=quadro,driver>=525,driver<526
+ brand=quadrortx,driver>=525,driver<526 brand=titan,driver>=525,driver<526 brand=titanrtx,driver>=525,driver<526
+ PYTHON_VERSION: '3.11'
+ PYTORCH_NIGHTLY_URL: ''
+ PYTORCH_NIGHTLY_VERSION: ''
+ PYTORCH_VERSION: 2.2.0
+ TAGS:
+ - mosaicml/pytorch:2.2.0_cu121-python3.11-ubuntu20.04-aws
+ TARGET: pytorch_stage
+ TORCHVISION_VERSION: 0.17.0
+- AWS_OFI_NCCL_VERSION: ''
+ BASE_IMAGE: ubuntu:20.04
+ CUDA_VERSION: ''
+ IMAGE_NAME: torch-2-2-0-cpu
+ MOFED_VERSION: ''
+ NVIDIA_REQUIRE_CUDA_OVERRIDE: ''
+ PYTHON_VERSION: '3.11'
+ PYTORCH_NIGHTLY_URL: ''
+ PYTORCH_NIGHTLY_VERSION: ''
+ PYTORCH_VERSION: 2.2.0
+ TAGS:
+ - mosaicml/pytorch:2.2.0_cpu-python3.11-ubuntu20.04
+ TARGET: pytorch_stage
+ TORCHVISION_VERSION: 0.17.0
+- AWS_OFI_NCCL_VERSION: ''
+ BASE_IMAGE: nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04
+ CUDA_VERSION: 12.1.0
+ IMAGE_NAME: torch-2-1-2-cu121
MOFED_VERSION: 5.5-1.0.3.2
NVIDIA_REQUIRE_CUDA_OVERRIDE: cuda>=12.1 brand=tesla,driver>=450,driver<451 brand=tesla,driver>=470,driver<471
brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471
@@ -21,16 +89,16 @@
PYTHON_VERSION: '3.10'
PYTORCH_NIGHTLY_URL: ''
PYTORCH_NIGHTLY_VERSION: ''
- PYTORCH_VERSION: 2.1.1
+ PYTORCH_VERSION: 2.1.2
TAGS:
- - mosaicml/pytorch:2.1.1_cu121-python3.10-ubuntu20.04
+ - mosaicml/pytorch:2.1.2_cu121-python3.10-ubuntu20.04
- mosaicml/pytorch:latest
TARGET: pytorch_stage
- TORCHVISION_VERSION: 0.16.1
+ TORCHVISION_VERSION: 0.16.2
- AWS_OFI_NCCL_VERSION: v1.7.4-aws
BASE_IMAGE: nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04
CUDA_VERSION: 12.1.0
- IMAGE_NAME: torch-2-1-1-cu121-aws
+ IMAGE_NAME: torch-2-1-2-cu121-aws
MOFED_VERSION: ''
NVIDIA_REQUIRE_CUDA_OVERRIDE: cuda>=12.1 brand=tesla,driver>=450,driver<451 brand=tesla,driver>=470,driver<471
brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471
@@ -49,27 +117,27 @@
PYTHON_VERSION: '3.10'
PYTORCH_NIGHTLY_URL: ''
PYTORCH_NIGHTLY_VERSION: ''
- PYTORCH_VERSION: 2.1.1
+ PYTORCH_VERSION: 2.1.2
TAGS:
- - mosaicml/pytorch:2.1.1_cu121-python3.10-ubuntu20.04-aws
+ - mosaicml/pytorch:2.1.2_cu121-python3.10-ubuntu20.04-aws
- mosaicml/pytorch:latest-aws
TARGET: pytorch_stage
- TORCHVISION_VERSION: 0.16.1
+ TORCHVISION_VERSION: 0.16.2
- AWS_OFI_NCCL_VERSION: ''
BASE_IMAGE: ubuntu:20.04
CUDA_VERSION: ''
- IMAGE_NAME: torch-2-1-1-cpu
+ IMAGE_NAME: torch-2-1-2-cpu
MOFED_VERSION: ''
NVIDIA_REQUIRE_CUDA_OVERRIDE: ''
PYTHON_VERSION: '3.10'
PYTORCH_NIGHTLY_URL: ''
PYTORCH_NIGHTLY_VERSION: ''
- PYTORCH_VERSION: 2.1.1
+ PYTORCH_VERSION: 2.1.2
TAGS:
- - mosaicml/pytorch:2.1.1_cpu-python3.10-ubuntu20.04
+ - mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
- mosaicml/pytorch:latest_cpu
TARGET: pytorch_stage
- TORCHVISION_VERSION: 0.16.1
+ TORCHVISION_VERSION: 0.16.2
- AWS_OFI_NCCL_VERSION: ''
BASE_IMAGE: nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04
CUDA_VERSION: 11.8.0
@@ -122,52 +190,64 @@
- mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
TARGET: pytorch_stage
TORCHVISION_VERSION: 0.15.2
-- AWS_OFI_NCCL_VERSION: ''
- BASE_IMAGE: nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04
- CUDA_VERSION: 11.7.1
- IMAGE_NAME: torch-1-13-1-cu117
- MOFED_VERSION: 5.5-1.0.3.2
- NVIDIA_REQUIRE_CUDA_OVERRIDE: ''
- PYTHON_VERSION: '3.10'
- PYTORCH_NIGHTLY_URL: ''
- PYTORCH_NIGHTLY_VERSION: ''
- PYTORCH_VERSION: 1.13.1
- TAGS:
- - mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
- TARGET: pytorch_stage
- TORCHVISION_VERSION: 0.14.1
- AWS_OFI_NCCL_VERSION: v1.7.4-aws
- BASE_IMAGE: nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04
- CUDA_VERSION: 11.7.1
- IMAGE_NAME: torch-1-13-1-cu117-aws
+ BASE_IMAGE: nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04
+ CUDA_VERSION: 12.1.0
+ IMAGE_NAME: torch-nightly-2-3-0-20240110-cu121-python3-10-aws
MOFED_VERSION: ''
- NVIDIA_REQUIRE_CUDA_OVERRIDE: ''
+ NVIDIA_REQUIRE_CUDA_OVERRIDE: cuda>=12.1 brand=tesla,driver>=450,driver<451 brand=tesla,driver>=470,driver<471
+ brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471
+ brand=geforce,driver>=470,driver<471 brand=geforcertx,driver>=470,driver<471 brand=quadro,driver>=470,driver<471
+ brand=quadrortx,driver>=470,driver<471 brand=titan,driver>=470,driver<471 brand=titanrtx,driver>=470,driver<471
+ brand=tesla,driver>=510,driver<511 brand=unknown,driver>=510,driver<511 brand=nvidia,driver>=510,driver<511
+ brand=nvidiartx,driver>=510,driver<511 brand=geforce,driver>=510,driver<511 brand=geforcertx,driver>=510,driver<511
+ brand=quadro,driver>=510,driver<511 brand=quadrortx,driver>=510,driver<511 brand=titan,driver>=510,driver<511
+ brand=titanrtx,driver>=510,driver<511 brand=tesla,driver>=515,driver<516 brand=unknown,driver>=515,driver<516
+ brand=nvidia,driver>=515,driver<516 brand=nvidiartx,driver>=515,driver<516 brand=geforce,driver>=515,driver<516
+ brand=geforcertx,driver>=515,driver<516 brand=quadro,driver>=515,driver<516 brand=quadrortx,driver>=515,driver<516
+ brand=titan,driver>=515,driver<516 brand=titanrtx,driver>=515,driver<516 brand=tesla,driver>=525,driver<526
+ brand=unknown,driver>=525,driver<526 brand=nvidia,driver>=525,driver<526 brand=nvidiartx,driver>=525,driver<526
+ brand=geforce,driver>=525,driver<526 brand=geforcertx,driver>=525,driver<526 brand=quadro,driver>=525,driver<526
+ brand=quadrortx,driver>=525,driver<526 brand=titan,driver>=525,driver<526 brand=titanrtx,driver>=525,driver<526
PYTHON_VERSION: '3.10'
- PYTORCH_NIGHTLY_URL: ''
- PYTORCH_NIGHTLY_VERSION: ''
- PYTORCH_VERSION: 1.13.1
+ PYTORCH_NIGHTLY_URL: https://download.pytorch.org/whl/nightly/cu121
+ PYTORCH_NIGHTLY_VERSION: dev20240110+cu121
+ PYTORCH_VERSION: 2.3.0
TAGS:
- - mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04-aws
+ - mosaicml/pytorch:2.3.0_cu121-nightly20240110-python3.10-ubuntu20.04-aws
TARGET: pytorch_stage
- TORCHVISION_VERSION: 0.14.1
+ TORCHVISION_VERSION: 0.18.0
- AWS_OFI_NCCL_VERSION: ''
- BASE_IMAGE: ubuntu:20.04
- CUDA_VERSION: ''
- IMAGE_NAME: torch-1-13-1-cpu
- MOFED_VERSION: ''
- NVIDIA_REQUIRE_CUDA_OVERRIDE: ''
+ BASE_IMAGE: nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04
+ CUDA_VERSION: 12.1.0
+ IMAGE_NAME: torch-nightly-2-3-0-20240110-cu121-python3-10
+ MOFED_VERSION: 5.5-1.0.3.2
+ NVIDIA_REQUIRE_CUDA_OVERRIDE: cuda>=12.1 brand=tesla,driver>=450,driver<451 brand=tesla,driver>=470,driver<471
+ brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471
+ brand=geforce,driver>=470,driver<471 brand=geforcertx,driver>=470,driver<471 brand=quadro,driver>=470,driver<471
+ brand=quadrortx,driver>=470,driver<471 brand=titan,driver>=470,driver<471 brand=titanrtx,driver>=470,driver<471
+ brand=tesla,driver>=510,driver<511 brand=unknown,driver>=510,driver<511 brand=nvidia,driver>=510,driver<511
+ brand=nvidiartx,driver>=510,driver<511 brand=geforce,driver>=510,driver<511 brand=geforcertx,driver>=510,driver<511
+ brand=quadro,driver>=510,driver<511 brand=quadrortx,driver>=510,driver<511 brand=titan,driver>=510,driver<511
+ brand=titanrtx,driver>=510,driver<511 brand=tesla,driver>=515,driver<516 brand=unknown,driver>=515,driver<516
+ brand=nvidia,driver>=515,driver<516 brand=nvidiartx,driver>=515,driver<516 brand=geforce,driver>=515,driver<516
+ brand=geforcertx,driver>=515,driver<516 brand=quadro,driver>=515,driver<516 brand=quadrortx,driver>=515,driver<516
+ brand=titan,driver>=515,driver<516 brand=titanrtx,driver>=515,driver<516 brand=tesla,driver>=525,driver<526
+ brand=unknown,driver>=525,driver<526 brand=nvidia,driver>=525,driver<526 brand=nvidiartx,driver>=525,driver<526
+ brand=geforce,driver>=525,driver<526 brand=geforcertx,driver>=525,driver<526 brand=quadro,driver>=525,driver<526
+ brand=quadrortx,driver>=525,driver<526 brand=titan,driver>=525,driver<526 brand=titanrtx,driver>=525,driver<526
PYTHON_VERSION: '3.10'
- PYTORCH_NIGHTLY_URL: ''
- PYTORCH_NIGHTLY_VERSION: ''
- PYTORCH_VERSION: 1.13.1
+ PYTORCH_NIGHTLY_URL: https://download.pytorch.org/whl/nightly/cu121
+ PYTORCH_NIGHTLY_VERSION: dev20240110+cu121
+ PYTORCH_VERSION: 2.3.0
TAGS:
- - mosaicml/pytorch:1.13.1_cpu-python3.10-ubuntu20.04
+ - mosaicml/pytorch:2.3.0_cu121-nightly20240110-python3.10-ubuntu20.04
TARGET: pytorch_stage
- TORCHVISION_VERSION: 0.14.1
+ TORCHVISION_VERSION: 0.18.0
- AWS_OFI_NCCL_VERSION: ''
BASE_IMAGE: nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04
CUDA_VERSION: 12.1.0
- IMAGE_NAME: torch-nightly-2-2-0-20231213-cu121
+ IMAGE_NAME: torch-nightly-2-3-0-20240110-cu121-python3-11
MOFED_VERSION: 5.5-1.0.3.2
NVIDIA_REQUIRE_CUDA_OVERRIDE: cuda>=12.1 brand=tesla,driver>=450,driver<451 brand=tesla,driver>=470,driver<471
brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471
@@ -183,19 +263,19 @@
brand=unknown,driver>=525,driver<526 brand=nvidia,driver>=525,driver<526 brand=nvidiartx,driver>=525,driver<526
brand=geforce,driver>=525,driver<526 brand=geforcertx,driver>=525,driver<526 brand=quadro,driver>=525,driver<526
brand=quadrortx,driver>=525,driver<526 brand=titan,driver>=525,driver<526 brand=titanrtx,driver>=525,driver<526
- PYTHON_VERSION: '3.10'
+ PYTHON_VERSION: '3.11'
PYTORCH_NIGHTLY_URL: https://download.pytorch.org/whl/nightly/cu121
- PYTORCH_NIGHTLY_VERSION: dev20231213+cu121
- PYTORCH_VERSION: 2.2.0
+ PYTORCH_NIGHTLY_VERSION: dev20240110+cu121
+ PYTORCH_VERSION: 2.3.0
TAGS:
- - mosaicml/pytorch:2.2.0_cu121-nightly20231213-python3.10-ubuntu20.04
+ - mosaicml/pytorch:2.3.0_cu121-nightly20240110-python3.11-ubuntu20.04
TARGET: pytorch_stage
TORCHVISION_VERSION: 0.18.0
- AWS_OFI_NCCL_VERSION: ''
BASE_IMAGE: nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04
- COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.17.2
+ COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.19.1
CUDA_VERSION: 12.1.0
- IMAGE_NAME: composer-0-17-2
+ IMAGE_NAME: composer-0-19-1
MOFED_VERSION: 5.5-1.0.3.2
NVIDIA_REQUIRE_CUDA_OVERRIDE: cuda>=12.1 brand=tesla,driver>=450,driver<451 brand=tesla,driver>=470,driver<471
brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471
@@ -214,25 +294,25 @@
PYTHON_VERSION: '3.10'
PYTORCH_NIGHTLY_URL: ''
PYTORCH_NIGHTLY_VERSION: ''
- PYTORCH_VERSION: 2.1.1
+ PYTORCH_VERSION: 2.1.2
TAGS:
- - mosaicml/composer:0.17.2
+ - mosaicml/composer:0.19.1
- mosaicml/composer:latest
TARGET: composer_stage
- TORCHVISION_VERSION: 0.16.1
+ TORCHVISION_VERSION: 0.16.2
- AWS_OFI_NCCL_VERSION: ''
BASE_IMAGE: ubuntu:20.04
- COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.17.2
+ COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.19.1
CUDA_VERSION: ''
- IMAGE_NAME: composer-0-17-2-cpu
+ IMAGE_NAME: composer-0-19-1-cpu
MOFED_VERSION: 5.5-1.0.3.2
NVIDIA_REQUIRE_CUDA_OVERRIDE: ''
PYTHON_VERSION: '3.10'
PYTORCH_NIGHTLY_URL: ''
PYTORCH_NIGHTLY_VERSION: ''
- PYTORCH_VERSION: 2.1.1
+ PYTORCH_VERSION: 2.1.2
TAGS:
- - mosaicml/composer:0.17.2_cpu
+ - mosaicml/composer:0.19.1_cpu
- mosaicml/composer:latest_cpu
TARGET: composer_stage
- TORCHVISION_VERSION: 0.16.1
+ TORCHVISION_VERSION: 0.16.2
diff --git a/docker/generate_build_matrix.py b/docker/generate_build_matrix.py
index a7dca9bc50..333010304b 100644
--- a/docker/generate_build_matrix.py
+++ b/docker/generate_build_matrix.py
@@ -18,17 +18,17 @@
import tabulate
import yaml
-LATEST_PYTHON_VERSION = '3.10'
-PRODUCTION_PYTORCH_VERSION = '2.1.1'
+PRODUCTION_PYTHON_VERSION = '3.10'
+PRODUCTION_PYTORCH_VERSION = '2.1.2'
def _get_torchvision_version(pytorch_version: str):
- if pytorch_version == '2.1.1':
- return '0.16.1'
+ if pytorch_version == '2.2.0':
+ return '0.17.0'
+ if pytorch_version == '2.1.2':
+ return '0.16.2'
if pytorch_version == '2.0.1':
return '0.15.2'
- if pytorch_version == '1.13.1':
- return '0.14.1'
raise ValueError(f'Invalid pytorch_version: {pytorch_version}')
@@ -39,14 +39,15 @@ def _get_base_image(cuda_version: str):
def _get_cuda_version(pytorch_version: str, use_cuda: bool):
+ # From https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/
if not use_cuda:
return ''
- if pytorch_version == '2.1.1':
+ if pytorch_version == '2.2.0':
+ return '12.1.0'
+ if pytorch_version == '2.1.2':
return '12.1.0'
if pytorch_version == '2.0.1':
return '11.8.0'
- if pytorch_version == '1.13.1':
- return '11.7.1'
raise ValueError(f'Invalid pytorch_version: {pytorch_version}')
@@ -81,8 +82,7 @@ def _get_cuda_override(cuda_version: str):
'brand=titan,driver>=525,driver<526 brand=titanrtx,driver>=525,driver<526')
return cuda_121_override_string
-
- if cuda_version == '11.8.0':
+ elif cuda_version == '11.8.0':
cuda_118_override_string = ('cuda>=11.8 brand=tesla,driver>=470,driver<471 '
'brand=tesla,driver>=515,driver<516 brand=unknown,driver>=470,driver<471 '
'brand=unknown,driver>=515,driver<516 brand=nvidia,driver>=470,driver<471 '
@@ -92,9 +92,7 @@ def _get_cuda_override(cuda_version: str):
'brand=quadro,driver>=515,driver<516 brand=titan,driver>=470,driver<471 '
'brand=titan,driver>=515,driver<516 brand=titanrtx,driver>=470,driver<471 '
'brand=titanrtx,driver>=515,driver<516')
-
return cuda_118_override_string
-
return ''
@@ -106,7 +104,7 @@ def _get_pytorch_tags(python_version: str, pytorch_version: str, cuda_version: s
cuda_version_tag = _get_cuda_version_tag(cuda_version)
tags = [f'{base_image_name}:{pytorch_version}_{cuda_version_tag}-python{python_version}-ubuntu20.04']
- if python_version == LATEST_PYTHON_VERSION and pytorch_version == PRODUCTION_PYTORCH_VERSION:
+ if python_version == PRODUCTION_PYTHON_VERSION and pytorch_version == PRODUCTION_PYTORCH_VERSION:
if not cuda_version:
tags.append(f'{base_image_name}:latest_cpu')
else:
@@ -165,16 +163,15 @@ def _write_table(table_tag: str, table_contents: str):
def _main():
- python_versions = ['3.10']
- pytorch_versions = ['2.1.1', '2.0.1', '1.13.1']
+ python_pytorch_versions = [('3.11', '2.2.0'), ('3.10', '2.1.2'), ('3.10', '2.0.1')]
cuda_options = [True, False]
stages = ['pytorch_stage']
interconnects = ['mellanox', 'EFA'] # mellanox is default, EFA needed for AWS
pytorch_entries = []
- for product in itertools.product(python_versions, pytorch_versions, cuda_options, stages, interconnects):
- python_version, pytorch_version, use_cuda, stage, interconnect = product
+ for product in itertools.product(python_pytorch_versions, cuda_options, stages, interconnects):
+ (python_version, pytorch_version), use_cuda, stage, interconnect = product
cuda_version = _get_cuda_version(pytorch_version=pytorch_version, use_cuda=use_cuda)
@@ -209,9 +206,8 @@ def _main():
_get_cuda_override(cuda_version),
}
- # Only build EFA image on latest python with cuda on pytorch_stage
- if interconnect == 'EFA' and not (python_version == LATEST_PYTHON_VERSION and use_cuda and
- stage == 'pytorch_stage'):
+ # Only build EFA image on cuda and pytorch_stage
+ if interconnect == 'EFA' and not (use_cuda and stage == 'pytorch_stage'):
continue
# Skip the mellanox drivers if not in the cuda images or using EFA
@@ -227,27 +223,63 @@ def _main():
entry['AWS_OFI_NCCL_VERSION'] = 'v1.7.4-aws'
pytorch_entries.append(entry)
- nightly_entry = {
+
+ nightly_entry_310_aws = {
+ 'AWS_OFI_NCCL_VERSION': 'v1.7.4-aws',
+ 'BASE_IMAGE': 'nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04',
+ 'CUDA_VERSION': '12.1.0',
+ 'IMAGE_NAME': 'torch-nightly-2-3-0-20240110-cu121-python3-10-aws',
+ 'MOFED_VERSION': '',
+ 'NVIDIA_REQUIRE_CUDA_OVERRIDE': _get_cuda_override('12.1.0'),
+ 'PYTHON_VERSION': '3.10',
+ 'PYTORCH_VERSION': '2.3.0',
+ 'PYTORCH_NIGHTLY_URL': 'https://download.pytorch.org/whl/nightly/cu121',
+ 'PYTORCH_NIGHTLY_VERSION': 'dev20240110+cu121',
+ 'TAGS': ['mosaicml/pytorch:2.3.0_cu121-nightly20240110-python3.10-ubuntu20.04-aws'],
+ 'TARGET': 'pytorch_stage',
+ 'TORCHVISION_VERSION': '0.18.0'
+ }
+ pytorch_entries.append(nightly_entry_310_aws)
+
+ nightly_entry_310 = {
'AWS_OFI_NCCL_VERSION': '',
'BASE_IMAGE': 'nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04',
'CUDA_VERSION': '12.1.0',
- 'IMAGE_NAME': 'torch-nightly-2-2-0-20231213-cu121',
+ 'IMAGE_NAME': 'torch-nightly-2-3-0-20240110-cu121-python3-10',
'MOFED_VERSION': '5.5-1.0.3.2',
'NVIDIA_REQUIRE_CUDA_OVERRIDE': _get_cuda_override('12.1.0'),
'PYTHON_VERSION': '3.10',
- 'PYTORCH_VERSION': '2.2.0',
+ 'PYTORCH_VERSION': '2.3.0',
+ 'PYTORCH_NIGHTLY_URL': 'https://download.pytorch.org/whl/nightly/cu121',
+ 'PYTORCH_NIGHTLY_VERSION': 'dev20240110+cu121',
+ 'TAGS': ['mosaicml/pytorch:2.3.0_cu121-nightly20240110-python3.10-ubuntu20.04'],
+ 'TARGET': 'pytorch_stage',
+ 'TORCHVISION_VERSION': '0.18.0'
+ }
+ pytorch_entries.append(nightly_entry_310)
+
+ nightly_entry_311 = {
+ 'AWS_OFI_NCCL_VERSION': '',
+ 'BASE_IMAGE': 'nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04',
+ 'CUDA_VERSION': '12.1.0',
+ 'IMAGE_NAME': 'torch-nightly-2-3-0-20240110-cu121-python3-11',
+ 'MOFED_VERSION': '5.5-1.0.3.2',
+ 'NVIDIA_REQUIRE_CUDA_OVERRIDE': _get_cuda_override('12.1.0'),
+ 'PYTHON_VERSION': '3.11',
+ 'PYTORCH_VERSION': '2.3.0',
'PYTORCH_NIGHTLY_URL': 'https://download.pytorch.org/whl/nightly/cu121',
- 'PYTORCH_NIGHTLY_VERSION': 'dev20231213+cu121',
- 'TAGS': ['mosaicml/pytorch:2.2.0_cu121-nightly20231213-python3.10-ubuntu20.04'],
+ 'PYTORCH_NIGHTLY_VERSION': 'dev20240110+cu121',
+ 'TAGS': ['mosaicml/pytorch:2.3.0_cu121-nightly20240110-python3.11-ubuntu20.04'],
'TARGET': 'pytorch_stage',
'TORCHVISION_VERSION': '0.18.0'
}
- pytorch_entries.append(nightly_entry)
+ pytorch_entries.append(nightly_entry_311)
+
composer_entries = []
# The `GIT_COMMIT` is a placeholder and Jenkins will substitute it with the actual git commit for the `composer_staging` images
- composer_versions = ['0.17.2'] # Only build images for the latest composer version
- composer_python_versions = [LATEST_PYTHON_VERSION] # just build composer against the latest
+ composer_versions = ['0.19.1'] # Only build images for the latest composer version
+ composer_python_versions = [PRODUCTION_PYTHON_VERSION] # just build composer against the latest
for product in itertools.product(composer_python_versions, composer_versions, cuda_options):
python_version, composer_version, use_cuda = product
diff --git a/docs/source/composer_model.rst b/docs/source/composer_model.rst
index 3f4c32dab8..bd80be1d10 100644
--- a/docs/source/composer_model.rst
+++ b/docs/source/composer_model.rst
@@ -75,8 +75,6 @@ We also provide several common classes for various tasks, specifically:
- :class:`.ComposerClassifier` - classification tasks with a cross entropy
loss and accuracy metric.
-- :func:`.composer_timm` - creates classification models from the popular `TIMM`_
- library.
- :class:`.HuggingFaceModel` - :class:`.ComposerModel` wrapper for a đ¤ `Transformers`_ model.
.. note::
@@ -195,18 +193,6 @@ Integrations
------------
-
-TIMM
-~~~~
-
-Integrate with your favorite `TIMM`_ models with our :func:`.composer_timm` function.
-
-.. code:: python
-
- from composer.models import composer_timm
-
- timm_model = composer_timm(model_name='resnet50', pretrained=True)
-
BERT Example with đ¤ Transformers
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -265,5 +251,4 @@ and make it compatible with our trainer.
.. |loss| replace:: :meth:`~.ComposerModel.loss`
.. _MMDetection: https://mmdetection.readthedocs.io/en/latest/
.. _Transformers: https://huggingface.co/docs/transformers/index
-.. _TIMM: https://timm.fast.ai/
.. _torchvision: https://pytorch.org/vision/stable/models.html
diff --git a/docs/source/doctest_fixtures.py b/docs/source/doctest_fixtures.py
index 2b640283b3..91b7c909b8 100644
--- a/docs/source/doctest_fixtures.py
+++ b/docs/source/doctest_fixtures.py
@@ -48,7 +48,6 @@
from composer.core import Timestamp as Timestamp
from composer.core import TimeUnit as TimeUnit
from composer.core import types as types
-from composer.datasets.synthetic import SyntheticBatchPairDataset
from composer.devices import DeviceCPU
from composer.loggers import InMemoryLogger as InMemoryLogger
from composer.loggers import Logger as Logger
@@ -72,6 +71,13 @@
except ImportError:
_COMETML_INSTALLED = False
+try:
+ import neptune
+ _NEPTUNE_INSTALLED = True
+ del neptune # unused
+except ImportError:
+ _NEPTUNE_INSTALLED = False
+
try:
import libcloud
_LIBCLOUD_INSTALLED = True
@@ -87,7 +93,7 @@
sys.path.insert(0, _repo_root)
from tests.common import SimpleModel
-from tests.common.datasets import RandomTextClassificationDataset
+from tests.common.datasets import RandomClassificationDataset, RandomTextClassificationDataset
# Disable mosaicml logger
os.environ['MOSAICML_PLATFORM'] = 'False'
@@ -112,11 +118,10 @@
scheduler = CosineAnnealingLR(optimizer, T_max=1)
-dataset = SyntheticBatchPairDataset(
- total_dataset_size=100,
- data_shape=data_shape,
+dataset = RandomClassificationDataset(
+ shape=data_shape,
+ size=100,
num_classes=num_classes,
- num_unique_samples_to_create=10,
)
train_dataset = dataset
diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst
index d55745608f..100247983a 100644
--- a/docs/source/getting_started/installation.rst
+++ b/docs/source/getting_started/installation.rst
@@ -20,16 +20,14 @@ the following installation targets are available:
and building documentation.
* ``pip install 'mosaicml[deepspeed]'``: Installs Composer with support for :mod:`deepspeed`.
* ``pip install 'mosaicml[nlp]'``: Installs Composer with support for NLP models and algorithms.
-* ``pip install 'mosaicml[unet]'``: Installs Composer with support for :doc:`Unet `.
-* ``pip install 'mosaicml[timm]'``: Installs Composer with support for :mod:`timm`.
* ``pip install 'mosaicml[wandb]'``: Installs Composer with support for :mod:`wandb`.
* ``pip install 'mosaicml[comet_ml]'``: Installs Composer with support for :mod:`comet_ml`.
+* ``pip install 'mosaicml[neptune]'``: Installs Composer with support for :mod:`neptune`.
* ``pip install 'mosaicml[tensorboard]'``: Installs Composer with support for :mod:`tensorboard`.
* ``pip install 'mosaicml[streaming]'``: Installs Composer with support for `streaming `_.
* ``pip install 'mosaicml[mlflow]'``: Installs Composer with support for :mod:`mlflow`.
* ``pip install 'mosaicml[oci]'``: Installs Composer with support for :mod:`oci`.
* ``pip install 'mosaicml[onnx]'``: Installs Composer with support for :mod:`onnx`.
-* ``pip install 'mosaicml[vit]'``: Installs Composer with support for :mod:`vit`.
* ``pip install 'mosaicml[coco]'``: Installs Composer with support for :mod:`coco`.
* ``pip install 'mosaicml[libcloud]'``: Installs Composer with support for :mod:`libcloud`.
* ``pip install 'mosaicml[all]'``: Installs all optional dependencies.
diff --git a/docs/source/getting_started/quick_start.rst b/docs/source/getting_started/quick_start.rst
index c3c7d6f7ed..f7613384ba 100644
--- a/docs/source/getting_started/quick_start.rst
+++ b/docs/source/getting_started/quick_start.rst
@@ -61,7 +61,7 @@ Besides easily running our built-in algorithms, Composer also features:
* An interface to flexibly add algorithms to the training loop
* An engine that manages the ordering of algorithms for composition
* A trainer to handle boilerplate around numerics, distributed training, and others
-* Integration with popular model libraries such as TIMM and HuggingFace Transformers
+* Integration with popular model libraries such as HuggingFace Transformers
Next steps
----------
diff --git a/docs/source/getting_started/welcome_tour.rst b/docs/source/getting_started/welcome_tour.rst
index a46dc85f33..649a9c87b0 100644
--- a/docs/source/getting_started/welcome_tour.rst
+++ b/docs/source/getting_started/welcome_tour.rst
@@ -65,6 +65,7 @@ We could add events to our training loop as follows:
.. code-block:: python
#
+ #
#
#
for epoch in range(NUM_EPOCHS):
diff --git a/docs/source/index.rst b/docs/source/index.rst
index ce95ba6e1b..425dcad93c 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -47,7 +47,6 @@ Composer is part of the broader Machine Learning community, and we welcome any c
examples/getting_started.ipynb
examples/functional_api.ipynb
- examples/medical_image_segmentation.ipynb
examples/custom_speedup_methods.ipynb
examples/finetune_huggingface.ipynb
examples/pretrain_finetune_huggingface.ipynb
@@ -136,19 +135,6 @@ Composer is part of the broader Machine Learning community, and we welcome any c
method_cards/swa.md
method_cards/weight_standardization.md
-.. toctree::
- :hidden:
- :maxdepth: 1
- :caption: Model Library
-
- model_cards/BERT.md
- model_cards/cifar_resnet.md
- model_cards/deeplabv3.md
- model_cards/efficientnet.md
- model_cards/GPT2.md
- model_cards/resnet.md
- model_cards/unet.md
-
.. toctree::
:hidden:
:caption: API Reference
diff --git a/docs/source/method_cards/decoupled_weight_decay.md b/docs/source/method_cards/decoupled_weight_decay.md
index 2d9f78f94f..71e0f4312f 100644
--- a/docs/source/method_cards/decoupled_weight_decay.md
+++ b/docs/source/method_cards/decoupled_weight_decay.md
@@ -16,9 +16,7 @@ L2 regularization is typically considered equivalent to weight decay, but this e
-
-```bash
-# Single GPU/CPU depending on torch.cuda.is_available()
-python train_resnet_imagenet1k.py /path/to/imagenet
-
-# Log experiments to Weights and Biases
-python train_resnet_imagenet1k.py /path/to/imagenet --wandb_logger --wandb_entity my_username --wandb_project my_project --wandb_run_name my_run_name
-
-# Single/Multi GPU training (infers the number of GPUs available)
-composer train_resnet_imagenet1k.py /path/to/imagenet
-
-# Manually specify number of GPUs to use:
-composer -n $N_GPUS train_resnet_imagenet1k.py /path/to/imagenet
-
-# Mild ResNet recipe for fastest training to ~76.5% accuracy:
-composer train_resnet_imagenet1k.py /path/to/imagenet --recipe_name mild --train_crop_size 176 --eval_crop_size 224 --max_duration 36ep --loss_name binary_cross_entropy
-
-# Medium ResNet recipe highest accuracy with similar training time as baseline:
-composer train_resnet_imagenet1k.py /path/to/imagenet --recipe_name medium --train_crop_size 176 --eval_crop_size 224 --max_duration 135ep --loss_name binary_cross_entropy
-
-# Spicy ResNet recipe for our most accurate ResNet over a long training schedule:
-composer train_resnet_imagenet1k.py /path/to/imagenet --recipe_name spicy --train_crop_size 176 --eval_crop_size 224 --max_duration 270ep --loss_name binary_cross_entropy
-```
diff --git a/examples/imagenet/train_resnet_imagenet1k.py b/examples/imagenet/train_resnet_imagenet1k.py
deleted file mode 100644
index d6f1dee008..0000000000
--- a/examples/imagenet/train_resnet_imagenet1k.py
+++ /dev/null
@@ -1,298 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""Example script to train a ResNet model on ImageNet."""
-
-import argparse
-import logging
-import os
-
-import torch
-from torch.utils.data import DataLoader
-from torchmetrics import MetricCollection
-from torchmetrics.classification import MulticlassAccuracy
-from torchvision import transforms
-from torchvision.datasets import ImageFolder
-from torchvision.models import resnet
-
-from composer import DataSpec, Time, Trainer
-from composer.algorithms import (EMA, SAM, BlurPool, ChannelsLast, ColOut, LabelSmoothing, MixUp, ProgressiveResizing,
- RandAugment, StochasticDepth)
-from composer.callbacks import CheckpointSaver, LRMonitor, SpeedMonitor
-from composer.datasets.utils import NormalizationFn, pil_image_collate
-from composer.loggers import WandBLogger
-from composer.loss import binary_cross_entropy_with_logits, soft_cross_entropy
-from composer.metrics import CrossEntropy
-from composer.models.tasks import ComposerClassifier
-from composer.optim import CosineAnnealingWithWarmupScheduler, DecoupledSGDW
-from composer.utils import dist
-
-logging.basicConfig()
-logging.getLogger().setLevel(logging.INFO)
-
-parser = argparse.ArgumentParser()
-
-# Dataloader arguments
-parser.add_argument('data_dir', help='Path to the directory containing the ImageNet-1k dataset', type=str)
-parser.add_argument('--train_crop_size', help='Training image crop size', type=int, default=224)
-parser.add_argument('--eval_resize_size', help='Evaluation image resize size', type=int, default=256)
-parser.add_argument('--eval_crop_size', help='Evaluation image crop size', type=int, default=224)
-parser.add_argument('--train_batch_size', help='Train dataloader per-device batch size', type=int, default=2048)
-parser.add_argument('--eval_batch_size', help='Validation dataloader per-device batch size', type=int, default=2048)
-
-# Model arguments
-parser.add_argument('--model_name',
- help='Name of the resnet model to train',
- default='resnet50',
- choices=['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'])
-parser.add_argument('--loss_name',
- help='Name of the loss function to use for training',
- default='cross_entropy',
- choices=['cross_entropy', 'binary_cross_entropy'])
-
-# Optimizer arguments
-parser.add_argument('--learning_rate', help='Optimizer learning rate', type=float, default=2.048)
-parser.add_argument('--momentum', help='Optimizer momentum', type=float, default=0.875)
-parser.add_argument('--weight_decay', help='Optimizer weight decay', type=float, default=5.0e-4)
-
-# LR scheduler arguments
-parser.add_argument('--t_warmup',
- help='Duration of learning rate warmup specified as a Time string',
- type=Time.from_timestring,
- default='8ep')
-parser.add_argument('--t_max',
- help='Duration to cosine decay the learning rate specified as a Time string',
- type=Time.from_timestring,
- default='1dur')
-
-# Save checkpoint arguments
-parser.add_argument('--save_checkpoint_dir',
- help='Directory in which to save model checkpoints',
- type=str,
- default='checkpoints/{run_name}')
-parser.add_argument('--checkpoint_interval', help='Frequency to save checkpoints', type=str, default='1ep')
-
-# Load checkpoint arguments, assumes resuming the previous training run instead of fine-tuning
-parser.add_argument('--load_checkpoint_path', help='Path to the checkpoint to load', type=str)
-
-# Recipes
-parser.add_argument('--recipe_name',
- help='Either "mild", "medium" or "spicy" in order of increasing training time and accuracy',
- type=str,
- choices=['mild', 'medium', 'spicy'])
-
-# Logger parameters: progress bar logging is used by default
-# Only has Weights and Biases option to reduce the number of arguments. Other loggers can be substituted in the script
-parser.add_argument('--wandb_logger', help='Whether or not to log results to Weights and Biases', action='store_true')
-parser.add_argument('--wandb_entity', help='WandB entity name', type=str)
-parser.add_argument('--wandb_project', help='WandB project name', type=str)
-parser.add_argument('--wandb_run_name', help='WandB run name', type=str)
-
-# Trainer arguments
-parser.add_argument('--run_name', help='Name of the training run used for checkpointing and other logging', type=str)
-parser.add_argument('--seed', help='Random seed', type=int, default=17)
-parser.add_argument('--max_duration',
- help='Duration to train specified as a Time string',
- type=Time.from_timestring,
- default='90ep')
-parser.add_argument('--eval_interval',
- help='How frequently to run evaluation on the validation set specified as a Time string',
- type=Time.from_timestring,
- default='1ep')
-
-args = parser.parse_args()
-
-
-def _main():
-
- # Divide batch sizes by number of devices if running multi-gpu training
- if dist.get_world_size():
- args.train_batch_size = args.train_batch_size // dist.get_world_size()
- args.eval_batch_size = args.eval_batch_size // dist.get_world_size()
-
- # Scale by 255 since the collate `pil_image_collate` results in images in range 0-255
- # If using ToTensor() and the default collate, remove the scaling by 255
- IMAGENET_CHANNEL_MEAN = (0.485 * 255, 0.456 * 255, 0.406 * 255)
- IMAGENET_CHANNEL_STD = (0.229 * 255, 0.224 * 255, 0.225 * 255)
-
- # Train dataset
- logging.info('Building train dataloader')
- train_transforms = transforms.Compose([
- transforms.RandomResizedCrop(args.train_crop_size, scale=(0.08, 1.0), ratio=(0.75, 4.0 / 3.0)),
- transforms.RandomHorizontalFlip(),
- ])
- train_dataset = ImageFolder(os.path.join(args.data_dir, 'train'), train_transforms)
- # Nifty function to instantiate a PyTorch DistributedSampler based on your hardware setup
- train_sampler = dist.get_sampler(train_dataset, drop_last=True, shuffle=True)
- train_dataloader = DataLoader(
- train_dataset,
- batch_size=args.train_batch_size,
- num_workers=8,
- pin_memory=True,
- drop_last=True,
- sampler=train_sampler,
- collate_fn=pil_image_collate,
- persistent_workers=True, # Reduce overhead of creating new workers at the expense of using slightly more RAM
- )
- # DataSpec allows for on-gpu transformations, marginally relieving dataloader bottleneck
- train_dataspec = DataSpec(dataloader=train_dataloader,
- device_transforms=NormalizationFn(mean=IMAGENET_CHANNEL_MEAN, std=IMAGENET_CHANNEL_STD))
- logging.info('Built train dataloader\n')
-
- # Validation dataset
- logging.info('Building evaluation dataloader')
- eval_transforms = transforms.Compose([
- transforms.Resize(args.eval_resize_size),
- transforms.CenterCrop(args.eval_crop_size),
- ])
- eval_dataset = ImageFolder(os.path.join(args.data_dir, 'val'), eval_transforms)
- # Nifty function to instantiate a PyTorch DistributedSampler based on your hardware setup,
- eval_sampler = dist.get_sampler(eval_dataset, drop_last=False, shuffle=False)
- eval_dataloader = DataLoader(
- eval_dataset,
- batch_size=args.eval_batch_size,
- num_workers=8,
- pin_memory=True,
- drop_last=False,
- sampler=eval_sampler,
- collate_fn=pil_image_collate,
- persistent_workers=True, # Reduce overhead of creating new workers at the expense of using slightly more RAM
- )
- eval_dataspec = DataSpec(dataloader=eval_dataloader,
- device_transforms=NormalizationFn(mean=IMAGENET_CHANNEL_MEAN, std=IMAGENET_CHANNEL_STD))
- logging.info('Built evaluation dataloader\n')
-
- # Instantiate torchvision ResNet model
- logging.info('Building Composer model')
- model_fn = getattr(resnet, args.model_name)
- model = model_fn(num_classes=1000, groups=1, width_per_group=64)
-
- # Specify model initialization
- def weight_init(w: torch.nn.Module):
- if isinstance(w, torch.nn.Linear) or isinstance(w, torch.nn.Conv2d):
- torch.nn.init.kaiming_normal_(w.weight)
- if isinstance(w, torch.nn.BatchNorm2d):
- w.weight.data = torch.rand(w.weight.data.shape)
- w.bias.data = torch.zeros_like(w.bias.data)
- # When using binary cross entropy, set the classification layer bias to -log(num_classes)
- # to ensure the initial probabilities are approximately 1 / num_classes
- if args.loss_name == 'binary_cross_entropy' and isinstance(w, torch.nn.Linear):
- w.bias.data = torch.ones(w.bias.shape) * -torch.log(torch.tensor(w.bias.shape[0]))
-
- model.apply(weight_init)
-
- # Performance metrics to log other than training loss
- train_metrics = MulticlassAccuracy(num_classes=1000, average='micro')
- val_metrics = MetricCollection([CrossEntropy(), MulticlassAccuracy(num_classes=1000, average='micro')])
-
- # Cross entropy loss that can handle both index and one-hot targets
-
- if args.loss_name == 'binary_cross_entropy':
- loss_fn = binary_cross_entropy_with_logits
- else:
- loss_fn = soft_cross_entropy
-
- # Wrapper function to convert a classification PyTorch model into a Composer model
- composer_model = ComposerClassifier(model, train_metrics=train_metrics, val_metrics=val_metrics, loss_fn=loss_fn)
- logging.info('Built Composer model\n')
-
- # Optimizer
- logging.info('Building optimizer and learning rate scheduler')
- optimizer = DecoupledSGDW(composer_model.parameters(),
- lr=args.learning_rate,
- momentum=args.momentum,
- weight_decay=args.weight_decay)
-
- # Learning rate scheduler: LR warmup for 8 epochs, then cosine decay for the rest of training
- lr_scheduler = CosineAnnealingWithWarmupScheduler(t_warmup=args.t_warmup, t_max=args.t_max)
- logging.info('Built optimizer and learning rate scheduler\n')
-
- # Callbacks for logging
- logging.info('Building SpeedMonitor, LRMonitor, and CheckpointSaver callbacks')
- speed_monitor = SpeedMonitor(window_size=50) # Measures throughput as samples/sec and tracks total training time
- lr_monitor = LRMonitor() # Logs the learning rate
-
- # Callback for checkpointing
- checkpoint_saver = CheckpointSaver(folder=args.save_checkpoint_dir, save_interval=args.checkpoint_interval)
- logging.info('Built SpeedMonitor, LRMonitor, and CheckpointSaver callbacks\n')
-
- # Recipes for training ResNet architectures on ImageNet in order of increasing training time and accuracy
- # To learn about individual methods, check out "Methods Overview" in our documentation: https://docs.mosaicml.com/
- logging.info('Building algorithm recipes')
- if args.recipe_name == 'mild':
- algorithms = [
- BlurPool(),
- ChannelsLast(),
- EMA(half_life='100ba', update_interval='20ba'),
- ProgressiveResizing(initial_scale=0.5, delay_fraction=0.4, finetune_fraction=0.2),
- LabelSmoothing(smoothing=0.08),
- ]
- elif args.recipe_name == 'medium':
- algorithms = [
- BlurPool(),
- ChannelsLast(),
- EMA(half_life='100ba', update_interval='20ba'),
- ProgressiveResizing(initial_scale=0.5, delay_fraction=0.4, finetune_fraction=0.2),
- LabelSmoothing(smoothing=0.1),
- MixUp(alpha=0.2),
- SAM(rho=0.5, interval=10),
- ]
- elif args.recipe_name == 'spicy':
- algorithms = [
- BlurPool(),
- ChannelsLast(),
- EMA(half_life='100ba', update_interval='20ba'),
- ProgressiveResizing(initial_scale=0.6, delay_fraction=0.2, finetune_fraction=0.2),
- LabelSmoothing(smoothing=0.13),
- MixUp(alpha=0.25),
- SAM(rho=0.5, interval=5),
- ColOut(p_col=0.05, p_row=0.05),
- RandAugment(depth=1, severity=9),
- StochasticDepth(target_layer_name='ResNetBottleneck',
- stochastic_method='sample',
- drop_distribution='linear',
- drop_rate=0.1)
- ]
- else:
- algorithms = None
- logging.info('Built algorithm recipes\n')
-
- logger = None
- if args.wandb_logger:
- if args.wandb_entity is None:
- raise ValueError('Please specify --wandb_entity argument')
- if args.wandb_project is None:
- raise ValueError('Please specify --wandb_project argument')
- if args.wandb_run_name is None:
- raise ValueError('Please specify --wandb_run_name argument')
- logger = WandBLogger(entity=args.wandb_entity, project=args.wandb_project, name=args.wandb_run_name)
-
- # Create the Trainer!
- logging.info('Building Trainer')
- device = 'gpu' if torch.cuda.is_available() else 'cpu'
- precision = 'amp' if device == 'gpu' else 'fp32' # Mixed precision for fast training when using a GPU
- trainer = Trainer(run_name=args.run_name,
- model=composer_model,
- train_dataloader=train_dataspec,
- eval_dataloader=eval_dataspec,
- eval_interval=args.eval_interval,
- optimizers=optimizer,
- schedulers=lr_scheduler,
- algorithms=algorithms,
- loggers=logger,
- max_duration=args.max_duration,
- callbacks=[speed_monitor, lr_monitor, checkpoint_saver],
- load_path=args.load_checkpoint_path,
- device=device,
- precision=precision,
- device_train_microbatch_size='auto',
- seed=args.seed)
- logging.info('Built Trainer\n')
-
- # Start training!
- logging.info('Train!')
- trainer.fit()
-
-
-if __name__ == '__main__':
- _main()
diff --git a/examples/medical_image_segmentation.ipynb b/examples/medical_image_segmentation.ipynb
deleted file mode 100644
index d13f88fbea..0000000000
--- a/examples/medical_image_segmentation.ipynb
+++ /dev/null
@@ -1,725 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# đŠē Image Segmentation"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "In this notebook you will use Composer and PyTorch to segment pneumothorax (air around or outside of the lungs) from chest radiographic images. This dataset was originally released for a [kaggle competition][kaggle] by the [Society for Informatics in Medicine][siim] (SIIM).\n",
- "\n",
- "**Disclaimer: This example represents a minimal working baseline. In order to get competitive results this notebook must run for a long time.**\n",
- "\n",
- "### Recommended Background\n",
- "\n",
- "This tutorial goes through the process of starting a project from scratch with Composer. It assumes you are fairly familiar with how such a process might look if working with PyTorch. In addition, it assumes some familiarity with computer vision models and methods.\n",
- "\n",
- "To better understand the Composer part, make sure you're comfortable with the material in our [Getting Started][getting_started] tutorial.\n",
- "\n",
- "### Tutorial Goals and Concepts Covered\n",
- "\n",
- "The goal of this tutorial is to provide an executable example of a computer vision project in Composer from the ground up.\n",
- "\n",
- "We will cover:\n",
- "\n",
- "- installing relevant packages\n",
- "- downloading the SIIM dataset from kaggle\n",
- "- cleaning and resampling the dataset\n",
- "- splitting data for validation\n",
- "- visualizing model inputs\n",
- "- training a baseline model with Composer\n",
- "- using Composer methods\n",
- "- next steps\n",
- "\n",
- "Let's get started!\n",
- "\n",
- "[kaggle]: https://www.kaggle.com/c/siim-acr-pneumothorax-segmentation/overview\n",
- "[siim]: https://siim.org/\n",
- "[getting_started]: https://docs.mosaicml.com/projects/composer/en/stable/examples/getting_started.html"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Setup\n",
- "\n",
- "Let's get started and configure our environment.\n",
- "\n",
- "### Install Dependencies\n",
- "\n",
- "If you haven't already, let's install the following dependencies, which are needed for this example:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "%pip install kaggle pydicom git+https://github.com/qubvel/segmentation_models.pytorch opencv-python-headless jupyterlab-widgets\n",
- "\n",
- "%pip install mosaicml\n",
- "# To install from source instead of the last release, comment the command above and uncomment the following one.\n",
- "# %pip install git+https://github.com/mosaicml/composer.git"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Kaggle Authentication\n",
- "\n",
- "To access the data you need a Kaggle Account\n",
- "- accept competition terms https://www.kaggle.com/c/siim-acr-pneumothorax-segmentation/data\n",
- "- download `kaggle.json` from https://www.kaggle.com/yourusername/account by clicking \"Create new API token\"\n",
- "- make the `kaggle.json` file available to this notebook using the following code cells."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from ipywidgets import FileUpload\n",
- "from IPython.display import display\n",
- "uploader = FileUpload(accept='.json', multiple=True)\n",
- "display(uploader)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import os\n",
- "\n",
- "kaggle_folder = os.path.join(os.path.expanduser(\"~\"), \".kaggle\")\n",
- "os.makedirs(kaggle_folder, exist_ok=True)\n",
- "kaggle_config_file = os.path.join(kaggle_folder, \"kaggle.json\")\n",
- "with open(kaggle_config_file, 'wb+') as output_file: \n",
- " for uploaded_filename in uploader.value:\n",
- " content = uploader.value[uploaded_filename]['content'] \n",
- " output_file.write(content) "
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Download and unzip the data \n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "!kaggle datasets download -d seesee/siim-train-test\n",
- "!unzip -q siim-train-test.zip -d .\n",
- "!ls"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Flatten Image Directories\n",
- "The original dataset is oddly nested. We flatten it out so the images are easier to access in our pytorch dataset.\n",
- "\n",
- "`/siim/dicom-images-train/id/id/id.dcm` to `/siim/dicom-images-train/id.dcm`. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from pathlib import Path\n",
- "from tqdm.auto import tqdm\n",
- "\n",
- "train_images = list(Path('siim/dicom-images-train').glob('*/*/*.dcm'))\n",
- "for image in tqdm(train_images):\n",
- " image.replace(f'siim/dicom-images-train/{image.parts[-1]}')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Project setup"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Imports"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import itertools\n",
- "from ipywidgets import interact, fixed, IntSlider\n",
- "\n",
- "import numpy as np\n",
- "import pandas as pd\n",
- "import torch\n",
- "from torch import nn\n",
- "import matplotlib.pyplot as plt\n",
- "import cv2\n",
- "\n",
- "# model\n",
- "import segmentation_models_pytorch as smp\n",
- "\n",
- "# data\n",
- "from torch.utils.data import DataLoader, Dataset\n",
- "from torchvision.utils import draw_segmentation_masks, make_grid\n",
- "from pydicom.filereader import dcmread\n",
- "from sklearn.model_selection import StratifiedKFold\n",
- "\n",
- "# transforms\n",
- "from albumentations import ShiftScaleRotate, Resize, Compose\n",
- "\n",
- "from torchmetrics import Metric\n",
- "from torchmetrics.collections import MetricCollection\n",
- "\n",
- "# composer\n",
- "from composer import Trainer\n",
- "from composer.models import ComposerModel\n",
- "from composer.optim import DecoupledAdamW\n",
- "from composer.metrics.metrics import Dice"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Utils\n",
- "\n",
- "Here we define some utility functions to help with logging, decoding/encoding targets, and visualization."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "class LossMetric(Metric):\n",
- " \"\"\"Turns any torch.nn Loss Module into distributed torchmetrics Metric.\"\"\"\n",
- "\n",
- " def __init__(self, loss, dist_sync_on_step=False):\n",
- " super().__init__(dist_sync_on_step=dist_sync_on_step)\n",
- " self.loss = loss\n",
- " self.add_state(\"sum_loss\", default=torch.tensor(0.), dist_reduce_fx=\"sum\")\n",
- " self.add_state(\"total_batches\", default=torch.tensor(0), dist_reduce_fx=\"sum\")\n",
- "\n",
- " def update(self, preds, target):\n",
- " \"\"\"Update the state with new predictions and targets.\n",
- " \"\"\"\n",
- " # Loss calculated over samples/batch, accumulate loss over all batches\n",
- " self.sum_loss += self.loss(preds, target)\n",
- " self.total_batches += 1\n",
- "\n",
- " def compute(self):\n",
- " \"\"\"Aggregate state over all processes and compute the metric.\n",
- " \"\"\"\n",
- " # Return average loss over entire validation dataset\n",
- " return self.sum_loss / self.total_batches\n",
- "\n",
- "def rle2mask(rle, height=1024, width=1024, fill_value=1):\n",
- " mask = np.zeros((height, width), np.float32)\n",
- " mask = mask.reshape(-1)\n",
- " rle = np.array([int(s) for s in rle.strip().split(' ')])\n",
- " rle = rle.reshape(-1, 2)\n",
- " start = 0\n",
- " for index, length in rle:\n",
- " start = start+index\n",
- " end = start+length\n",
- " mask[start: end] = fill_value\n",
- " start = end\n",
- " mask = mask.reshape(width, height).T\n",
- " return mask\n",
- "\n",
- "def mask2rle(mask):\n",
- " mask = mask.T.flatten()\n",
- " start = np.where(mask[1:] > mask[:-1])[0]+1\n",
- " end = np.where(mask[:-1] > mask[1:])[0]+1\n",
- " length = end-start\n",
- " rle = []\n",
- " for i in range(len(length)):\n",
- " if i == 0:\n",
- " rle.extend([start[0], length[0]])\n",
- " else:\n",
- " rle.extend([start[i]-end[i-1], length[i]])\n",
- " rle = ' '.join([str(r) for r in rle])\n",
- " return rle"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Preprocessing and Data Science"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### SIIM Dataset\n",
- "\n",
- "The SIIM dataset consists of:\n",
- "- `dicom-images-train` - 12954 labeled images in [DICOM][dicom] format.\n",
- "- `dicom-images-test` - 3205 unlabeled DICOM images for testing\n",
- "\n",
- "- `train-rle.csv` comes with a label file `train-rle.csv` mapping `ImageId` to `EncodedPixels`.\n",
- "\n",
- " - `ImageId`s map to image paths for [DICOM][dicom_format] format images. \n",
- "\n",
- " - `EncodedPixels` are [run length encoded][masks] segmentation masks representing areas where pneumothorax has been labeled by an expert. A label of `\"-1\"` indicates the image was examined and no pneumothorax was found.\n",
- "\n",
- "[dicom]: https://pydicom.github.io/pydicom/stable/auto_examples/input_output/plot_read_dicom\n",
- "[dicom_format]: https://pydicom.github.io/pydicom/stable/auto_examples/input_output/plot_read_dicom.html#sphx-glr-auto-examples-input-output-plot-read-dicom-py\n",
- "[masks]: https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/mask.py"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "!ls siim"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "labels_df = pd.read_csv('siim/train-rle.csv')\n",
- "labels_df.shape"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Clean Data\n",
- "Of the ~13,000 images, only 3600 have masks. We will throw out some of the negative samples to better balance our dataset and speed up training."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "labels_df[labels_df[\" EncodedPixels\"] != \"-1\"].shape, labels_df[labels_df[\" EncodedPixels\"] == \"-1\"].shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def balance_labels(labels_df, extra_samples_without_mask=1500, random_state=1337):\n",
- " \"\"\"\n",
- " Drop duplicates and mark samples with masks.\n",
- " Sample 3576+extra_samples_without_mask unmasked samples to balance dataset.\n",
- " \"\"\"\n",
- " df = labels_df.drop_duplicates('ImageId')\n",
- " df_with_mask = df[df[\" EncodedPixels\"] != \"-1\"].copy(deep=True)\n",
- " df_with_mask['has_mask'] = 1\n",
- " df_without_mask = df[df[\" EncodedPixels\"] == \"-1\"].copy(deep=True)\n",
- " df_without_mask['has_mask'] = 0\n",
- " df_without_mask_sampled = df_without_mask.sample(len(df_with_mask)+extra_samples_without_mask, random_state=random_state)\n",
- " df = pd.concat([df_with_mask, df_without_mask_sampled])\n",
- " return df"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "df = balance_labels(labels_df)\n",
- "df.shape"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Create Cross Validation Splits\n",
- "Once cleaned and balanced, we're left with only 6838 images. This will leave us with rather small training and validation sets once we split the data. To mitigate the chances of us validating on a poorly sampled (not representative of our unlabeled test data) validation set, we use [StratifiedKFold][kfold] to create 5 different 80%-20%, `train` `eval` splits. \n",
- "\n",
- "**Note**: For datasets of this size, it's good practice to train and evaluate on each split, but due to runtime constraints in this notebook we will only train on the first split which contains 5470 training and 1368 eval samples.\n",
- "\n",
- "[kfold]: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedKFold.html"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=1337)\n",
- "train_idx, eval_idx = list(kfold.split(df[\"ImageId\"], df[\"has_mask\"]))[0]\n",
- "train_df, eval_df = df.iloc[train_idx], df.iloc[eval_idx]\n",
- "train_df.shape, eval_df.shape"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## PyTorch\n",
- "\n",
- "### PyTorch Dataset\n",
- "`SIIMDataset` is a standard PyTorch dataset that reads images and decodes labels from the siim label csv. DICOM images are loaded as grayscale numpy arrays, converted to rgb, and scaled. Labels are converted from rle strings to binary segmentation masks. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "class SIIMDataset(Dataset):\n",
- " def __init__(self, \n",
- " labels_df,\n",
- " transforms=None,\n",
- " image_dir=Path('siim/dicom-images-train')):\n",
- " self.labels_df = labels_df\n",
- " self.image_dir = image_dir\n",
- " self.transforms = transforms\n",
- "\n",
- " def __getitem__(self, idx):\n",
- " row = self.labels_df.iloc[idx]\n",
- " image_id = row.ImageId\n",
- " image_path = self.image_dir / f'{image_id}.dcm'\n",
- " image = dcmread(image_path).pixel_array # load dicom image\n",
- " image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) # convert rgb so we can keep imagenet first layer weights\n",
- " image = (image / 255.).astype('float32') # scale (0.- 1.)\n",
- "\n",
- " rle = row[' EncodedPixels']\n",
- " if rle != '-1':\n",
- " mask = rle2mask(rle, 1024, 1024).astype('float32')\n",
- " else:\n",
- " mask = np.zeros([1024, 1024]).astype('float32')\n",
- "\n",
- " if self.transforms:\n",
- " augmented = self.transforms(image=image, mask=mask)\n",
- " image = augmented['image']\n",
- " mask = augmented['mask']\n",
- "\n",
- " return (\n",
- " torch.from_numpy(image).permute(2, 0, 1),\n",
- " torch.from_numpy(mask).unsqueeze(0)\n",
- " )\n",
- "\n",
- " def __len__(self):\n",
- " return len(self.labels_df)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Transforms\n",
- "We use the [albumentations](https://albumentations.ai/docs/getting_started/mask_augmentation/) library to resize and randomly scale/rotate our training images. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "image_size = 512\n",
- "\n",
- "train_transforms = Compose(\n",
- " [\n",
- " Resize(image_size, image_size),\n",
- " ShiftScaleRotate(\n",
- " shift_limit=0,\n",
- " scale_limit=0.1,\n",
- " rotate_limit=10, # rotate\n",
- " p=0.5,\n",
- " border_mode=cv2.BORDER_CONSTANT\n",
- " )\n",
- " ]\n",
- ")\n",
- "\n",
- "eval_transforms = Compose([Resize(image_size, image_size)])\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### DataLoaders"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "\n",
- "train_batch_size = 32\n",
- "val_batch_size = 32\n",
- "\n",
- "train_dataloader = DataLoader(SIIMDataset(train_df, transforms=train_transforms),\n",
- " batch_size=train_batch_size, shuffle=True, num_workers=2)\n",
- "\n",
- "eval_dataloader = DataLoader(SIIMDataset(eval_df, transforms=eval_transforms),\n",
- " batch_size=val_batch_size, shuffle=False, num_workers=2)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Visualize batch\n",
- "Areas of pneumothorax are highlighted in red; drag the slider to iterate through batches."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "@interact(data_loader=fixed(train_dataloader), batch=IntSlider(min=0, max=len(train_dataloader)-1, step=1, value=0))\n",
- "def show_batch(data_loader, batch):\n",
- " plt.rcParams['figure.figsize'] = [20, 15]\n",
- "\n",
- " images, masks = list(itertools.islice(data_loader, batch, batch+1))[0]\n",
- " masks_list = []\n",
- " for image, mask in zip(images, masks):\n",
- " masked = draw_segmentation_masks((image * 255).byte(),\n",
- " mask.bool(), alpha=0.5, colors='red')\n",
- " masks_list.append(masked)\n",
- "\n",
- " grid = make_grid(masks_list, nrow=6)\n",
- " plt.imshow(grid.permute(1, 2, 0));"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Composer\n",
- "\n",
- "### Model\n",
- "\n",
- "Here we define a Composer model that wraps the smp [segmentation models pytorch][pytorch_seg] package. This lets us quickly create many different segmentation models made from common pre-trained PyTorch encoders. \n",
- "\n",
- "- We set defaults to create a [Unet][unet] from an ImageNet pre-trained ResNet-34 with 3 input channels for our RGB (converted) inputs and 1 output channel. \n",
- "- We set the default loss to `nn.BCEWithLogitsLoss()` to classify each pixel of the output.\n",
- "\n",
- "[pytorch_seg]: https://github.com/qubvel/segmentation_models.pytorch\n",
- "[unet]: https://arxiv.org/abs/1505.04597"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "class SMPUNet(ComposerModel):\n",
- " def __init__(self,\n",
- " encoder_name='resnet34',\n",
- " encoder_weights='imagenet',\n",
- " in_channels=3, classes=1,\n",
- " loss=nn.BCEWithLogitsLoss()):\n",
- " super().__init__()\n",
- " self.model = smp.Unet(\n",
- " encoder_name=encoder_name,\n",
- " encoder_weights=encoder_weights, # use `imagenet` pre-trained weights for encoder initialization\n",
- " in_channels=in_channels, # model input channels (1 for gray-scale images, 3 for RGB, etc.)\n",
- " classes=classes # model output channels (number of classes in your dataset)\n",
- " ) \n",
- "\n",
- " self.criterion = loss\n",
- " self.train_loss = LossMetric(loss)\n",
- " self.val_loss = LossMetric(loss)\n",
- " self.val_dice = Dice(num_classes=classes)\n",
- "\n",
- " def forward(self, batch):\n",
- " images, targets = batch\n",
- " return self.model(images)\n",
- "\n",
- " def loss(self, outputs, batch):\n",
- " _, targets = batch\n",
- " return self.criterion(outputs, targets)\n",
- "\n",
- " def get_metrics(self, is_train: bool = False):\n",
- " if is_train:\n",
- " return {'BCEWithLogitsLoss': self.train_loss}\n",
- " else:\n",
- " return {'BCEWithLogitsLoss': self.val_loss, 'Dice': self.dice}"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "model = SMPUNet() # define unet model\n",
- "optimizer = DecoupledAdamW(model.parameters(), lr=1e-3)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Trainer"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "trainer = Trainer(\n",
- " model=model,\n",
- " train_dataloader=train_dataloader,\n",
- " eval_dataloader=eval_dataloader,\n",
- " max_duration='2ep',\n",
- " optimizers=optimizer,\n",
- " device='gpu',\n",
- " precision='amp',\n",
- " seed=1337\n",
- ")\n",
- "trainer.fit()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Algorithms\n",
- "\n",
- "Composer allows us to quickly experiment with algorithms that can speed up or improve the quality of our model. This is how we can add `CutOut` and `LabelSmoothing`\n",
- "\n",
- "Additionally, the Composer trainer has builtin support for automatic mixed precision training and gradient accumulation to help train quickly and simulate larger batch sizes."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from composer.algorithms import CutOut, LabelSmoothing\n",
- "\n",
- "model = SMPUNet() # define unet model\n",
- "optimizer = DecoupledAdamW(model.parameters(), lr=1e-3)\n",
- "\n",
- "algorithms = [CutOut(length=0.5), LabelSmoothing(smoothing=0.1)]\n",
- "\n",
- "trainer = Trainer(\n",
- " model=model,\n",
- " train_dataloader=train_dataloader,\n",
- " eval_dataloader=eval_dataloader,\n",
- " max_duration='2ep',\n",
- " optimizers=optimizer,\n",
- " algorithms=algorithms,\n",
- " device='gpu',\n",
- " precision='amp',\n",
- " seed=1337\n",
- ")\n",
- "trainer.fit()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- "## What next?\n",
- "\n",
- "You've now seen a from-scratch demonstration of using Composer in a computer vision project. But don't stop here! If you're interested, we recommend that you continue to experiment with:\n",
- "\n",
- "- training longer\n",
- "- different loss functions, architectures, transformations, and\n",
- "- different combinations of composer methods!\n",
- "\n",
- "In addition, please continue to explore our tutorials! Here are a couple suggestions:\n",
- "\n",
- "* Continue to explore more advanced applications of Composer like [fine-tuning a transformer for sentiment classification][huggingface_tutorial].\n",
- "\n",
- "* Learn about callbacks and how to apply [early stopping][early_stopping_tutorial].\n",
- "\n",
- "* See how dataloading bottlenecks in computer vision can be addressed using [FFCV][ffcv].\n",
- "\n",
- "[image_segmentation_tutorial]: https://docs.mosaicml.com/projects/composer/en/stable/examples/medical_image_segmentation.html\n",
- "[huggingface_tutorial]: https://docs.mosaicml.com/projects/composer/en/stable/examples/huggingface_models.html\n",
- "[early_stopping_tutorial]: https://docs.mosaicml.com/projects/composer/en/stable/examples/early_stopping.html\n",
- "[ffcv]: https://docs.mosaicml.com/projects/composer/en/stable/examples/ffcv_dataloaders.html"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Come get involved with MosaicML!\n",
- "\n",
- "We'd love for you to get involved with the MosaicML community in any of these ways:\n",
- "\n",
- "### [Star Composer on GitHub](https://github.com/mosaicml/composer)\n",
- "\n",
- "Help make others aware of our work by [starring Composer on GitHub](https://github.com/mosaicml/composer).\n",
- "\n",
- "### [Join the MosaicML Slack](https://join.slack.com/t/mosaicml-community/shared_invite/zt-w0tiddn9-WGTlRpfjcO9J5jyrMub1dg)\n",
- "\n",
- "Head on over to the [MosaicML slack](https://join.slack.com/t/mosaicml-community/shared_invite/zt-w0tiddn9-WGTlRpfjcO9J5jyrMub1dg) to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!\n",
- "\n",
- "### Contribute to Composer\n",
- "\n",
- "Is there a bug you noticed or a feature you'd like? File an [issue](https://github.com/mosaicml/composer/issues) or make a [pull request](https://github.com/mosaicml/composer/pulls)!"
- ]
- }
- ],
- "metadata": {
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 1
-}
diff --git a/examples/profiler_demo.py b/examples/profiler_demo.py
index f06fa17f06..d46c89e559 100644
--- a/examples/profiler_demo.py
+++ b/examples/profiler_demo.py
@@ -8,11 +8,13 @@
# [imports-start]
import torch
+import torch.nn as nn
+import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from composer import Trainer
-from composer.models import mnist_model
+from composer.models.tasks import ComposerClassifier
from composer.profiler import JSONTraceHandler, cyclic_schedule
from composer.profiler.profiler import Profiler
@@ -35,10 +37,39 @@
persistent_workers=True,
num_workers=8,
)
+
# [dataloader-end]
+
# Instantiate Model
-model = mnist_model(num_classes=10)
+class Model(nn.Module):
+ """Toy convolutional neural network architecture in pytorch for MNIST."""
+
+ def __init__(self, num_classes: int = 10):
+ super().__init__()
+
+ self.num_classes = num_classes
+
+ self.conv1 = nn.Conv2d(1, 16, (3, 3), padding=0)
+ self.conv2 = nn.Conv2d(16, 32, (3, 3), padding=0)
+ self.bn = nn.BatchNorm2d(32)
+ self.fc1 = nn.Linear(32 * 16, 32)
+ self.fc2 = nn.Linear(32, num_classes)
+
+ def forward(self, x):
+ out = self.conv1(x)
+ out = F.relu(out)
+ out = self.conv2(out)
+ out = self.bn(out)
+ out = F.relu(out)
+ out = F.adaptive_avg_pool2d(out, (4, 4))
+ out = torch.flatten(out, 1, -1)
+ out = self.fc1(out)
+ out = F.relu(out)
+ return self.fc2(out)
+
+
+model = ComposerClassifier(module=Model(num_classes=10))
# [trainer-start]
# Instantiate the trainer
diff --git a/examples/segmentation/README.md b/examples/segmentation/README.md
deleted file mode 100644
index 8eaa391184..0000000000
--- a/examples/segmentation/README.md
+++ /dev/null
@@ -1,41 +0,0 @@
-# Semantic Segmentation Example
-
-This example illustrates how to train a semantic segmentation model in composer.
-
-## Installation
-
-First, install [Composer](https://github.com/mosaicml/composer) with `pip install mosaicml`. Additionally, our models are pulled from [MMsegmentation](https://github.com/open-mmlab/mmsegmentation), so follow the [MMcv install instructions](https://mmcv.readthedocs.io/en/latest/get_started/installation.html) (which is dependent on your CUDA and PyTorch versions), then install MMsegmentation with `pip install mmsegmentation`.
-
-Alternatively, we have publicly available Docker images to reproduce our results. Use `mosaicml/pytorch_vision:1.12.1_cu116-python3.9-ubuntu20.04` for running on GPUs or `mosaicml/pytorch_vision:1.12.1_cpu-python3.9-ubuntu20.04` for running on CPUs.
-
-## DeepLabv3+ on ADE20k
-
-The `train_deeplabv3_ade20k.py` script trains a DeepLabv3+ model with either a ResNet-50 or ResNet-101 backbone on the ADE20k semantic segmentation benchmark. To download ADE20k locally (~1 GB), specify the `--download` option when running the script, then the dataset will be downloaded data directory path i.e. the first argument.
-
-We designed the script to be hackable, so try our recipes on your own models and datsets!
-### Example configurations
-
-
-
-```bash
-# Downloads ADE20k and does single GPU/CPU training depending on torch.cuda.is_available():
-python train_deeplabv3_ade20k.py /path/to/ade20k --download
-
-# Log experiments to Weights and Biases:
-python train_deeplabv3_ade20k.py /path/to/ade20k --wandb_logger --wandb_entity my_username --wandb_project my_project --run_name my_run_name
-
-# Single/Multi GPU training (infers the number of GPUs available):
-composer train_deeplabv3_ade20k.py /path/to/ade20k
-
-# Manually specify number of GPUs to use:
-composer -n $N_GPUS train_deeplabv3_ade20k.py /path/to/ade20k
-
-# Mild DeepLabv3+ recipe for fastest training to 45.6 mIoU:
-composer train_deeplabv3_ade20k.py /path/to/ade20k/ --recipe_name mild --max_duration 25ep
-
-# Medium DeepLabv3+ recipe for highest mIoU (49.15) with similar training time as baseline:
-composer train_deeplabv3_ade20k.py /path/to/ade20k/ --recipe_name medium --max_duration 90ep
-
-# Hot DeepLabv3+ recipe for highest mIoU (49.83) with a long training schedule:
-composer train_deeplabv3_ade20k.py /path/to/ade20k --recipe_name hot --max_duration 256ep
-```
diff --git a/examples/segmentation/train_deeplabv3_ade20k.py b/examples/segmentation/train_deeplabv3_ade20k.py
deleted file mode 100644
index 90d93aa037..0000000000
--- a/examples/segmentation/train_deeplabv3_ade20k.py
+++ /dev/null
@@ -1,367 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""Example script to train a DeepLabv3+ model on ADE20k for semantic segmentation."""
-
-import argparse
-import logging
-import os
-
-import torch
-import torchvision
-from torch.utils.data import DataLoader
-from torchmetrics import MetricCollection
-from torchvision import transforms
-from torchvision.transforms.functional import InterpolationMode
-
-from composer import DataSpec, Time, Trainer
-from composer.algorithms import EMA, SAM, ChannelsLast, MixUp
-from composer.callbacks import CheckpointSaver, ImageVisualizer, LRMonitor, SpeedMonitor
-from composer.datasets.ade20k import (ADE20k, PadToSize, PhotometricDistoration, RandomCropPair, RandomHFlipPair,
- RandomResizePair)
-from composer.datasets.utils import NormalizationFn, pil_image_collate
-from composer.loggers import WandBLogger
-from composer.loss import DiceLoss, soft_cross_entropy
-from composer.metrics import CrossEntropy, MIoU
-from composer.models import ComposerClassifier
-from composer.models.deeplabv3.model import deeplabv3
-from composer.optim import CosineAnnealingScheduler, DecoupledSGDW
-from composer.utils import dist
-
-logging.basicConfig()
-logging.getLogger().setLevel(logging.INFO)
-
-parser = argparse.ArgumentParser()
-
-# Dataloader command-line arguments
-parser.add_argument('data_dir', help='Path to the directory containing the ImageNet-1k dataset', type=str)
-parser.add_argument('--download',
- help='Use to download ADE20k from the internet and put it in the `data_dir`',
- action='store_true')
-parser.add_argument('--train_resize_size', help='Training image resize size', type=int, default=512)
-parser.add_argument('--eval_resize_size', help='Evaluation image resize size', type=int, default=512)
-parser.add_argument('--train_batch_size', help='Train dataloader per-device batch size', type=int, default=128)
-parser.add_argument('--eval_batch_size', help='Validation dataloader per-device batch size', type=int, default=128)
-
-# Model command-line arguments
-parser.add_argument('--backbone_arch',
- help='Architecture to use for the backbone.',
- default='resnet101',
- choices=['resnet50', 'resnet101'])
-parser.add_argument('--sync_bn',
- help='Use sync BatchNorm. Recommended if the per device microbatch size is below 16',
- action='store_true')
-parser.add_argument('--cross_entropy_weight', help='Weight to scale the cross entropy loss', type=float, default=0.375)
-parser.add_argument('--dice_weight', help='Weight to scale the dice loss', type=float, default=1.125)
-
-# Optimizer command-line arguments
-parser.add_argument('--learning_rate', help='Optimizer learning rate', type=float, default=0.08)
-parser.add_argument('--momentum', help='Optimizer momentum', type=float, default=0.9)
-parser.add_argument('--weight_decay', help='Optimizer weight decay', type=float, default=5.0e-5)
-
-# Save checkpoint command-line arguments
-parser.add_argument('--save_checkpoint_dir',
- help='Directory in which to save model checkpoints',
- type=str,
- default='checkpoints/{run_name}')
-parser.add_argument('--checkpoint_interval',
- help='Frequency to save checkpoints',
- type=Time.from_timestring,
- default='1ep')
-
-# Load checkpoint command-line arguments, assumes resuming from a previous training run (as opposed to fine-tuning)
-parser.add_argument('--load_checkpoint_path', help='Path to the checkpoint to load', type=str)
-
-# Recipes command-line argument
-parser.add_argument('--recipe_name',
- help='Algorithmic recipes to be applied to the trainer',
- choices=['mild', 'medium', 'hot'])
-
-# Logger command-line arguments
-# Note: Only Weights and Biases to minimize arguments. Other loggers can be used by adjusting the script
-parser.add_argument('--wandb_logger', help='Whether or not to log results to Weights and Biases', action='store_true')
-parser.add_argument('--wandb_entity', help='WandB entity name', type=str)
-parser.add_argument('--wandb_project', help='WandB project name', type=str)
-
-parser.add_argument('--image_viz', help='Whether or not to log images using ImageVisualizer', action='store_true')
-
-# Trainer arguments
-parser.add_argument('--device_train_microbatch_size',
- help='Size of train microbatch size if running on GPU',
- default='auto')
-parser.add_argument('--run_name', help='Name of the training run used for checkpointing and logging', type=str)
-parser.add_argument('--seed', help='Random seed', type=int, default=17)
-parser.add_argument('--max_duration',
- help='Duration to train specified as a Time string',
- type=Time.from_timestring,
- default='128ep')
-
-args = parser.parse_args()
-
-IMAGENET_CHANNEL_MEAN = (int(0.485 * 255), int(0.456 * 255), int(0.406 * 255))
-IMAGENET_CHANNEL_STD = (int(0.229 * 255), int(0.224 * 255), int(0.225 * 255))
-
-ADE20K_URL = 'http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip'
-ADE20K_FILE = 'ADEChallengeData2016.zip'
-
-
-def _main():
- # Divide batch size by number of devices
- if dist.get_world_size() > 1:
- args.train_batch_size = args.train_batch_size // dist.get_world_size()
- args.eval_batch_size = args.eval_batch_size // dist.get_world_size()
-
- # Train dataset code
- logging.info('Building train dataloader')
-
- if args.download:
- torchvision.datasets.utils.download_and_extract_archive(url=ADE20K_URL,
- download_root=args.data_dir,
- filename=ADE20K_FILE,
- remove_finished=True)
- # Adjust the data_dir to include the extracted directory
- args.data_dir = os.path.join(args.data_dir, 'ADEChallengeData2016')
-
- # Training transforms applied to both the image and target
- train_both_transforms = torch.nn.Sequential(
- RandomResizePair(
- min_scale=0.5,
- max_scale=2.0,
- base_size=(args.train_resize_size, args.train_resize_size),
- ),
- RandomCropPair(
- crop_size=(args.train_resize_size, args.train_resize_size),
- class_max_percent=0.75,
- num_retry=10,
- ),
- RandomHFlipPair(),
- )
-
- # Training transforms applied to the image only
- train_image_transforms = torch.nn.Sequential(
- PhotometricDistoration(
- brightness=32. / 255,
- contrast=0.5,
- saturation=0.5,
- hue=18. / 255,
- ),
- PadToSize(
- size=(args.train_resize_size, args.train_resize_size),
- fill=IMAGENET_CHANNEL_MEAN,
- ),
- )
-
- # Training transforms applied to the target only
- train_target_transforms = PadToSize(size=(args.train_resize_size, args.train_resize_size), fill=0)
-
- # Create ADE20k train dataset
- train_dataset = ADE20k(
- datadir=args.data_dir,
- split='training',
- image_transforms=train_image_transforms,
- target_transforms=train_target_transforms,
- both_transforms=train_both_transforms,
- )
-
- # Create ADE20k train dataloader
-
- train_sampler = None
- if dist.get_world_size():
- # Nifty function to instantiate a PyTorch DistributedSampler based on your hardware setup
- train_sampler = dist.get_sampler(train_dataset, drop_last=True, shuffle=True)
-
- train_dataloader = DataLoader(
- train_dataset,
- batch_size=args.train_batch_size,
- num_workers=8,
- pin_memory=True,
- drop_last=True, # Prevents using a smaller batch at the end of an epoch
- sampler=train_sampler,
- collate_fn=pil_image_collate,
- persistent_workers=True,
- )
-
- # DataSpec enables image normalization to be performed on-GPU, marginally relieving dataloader bottleneck
- train_dataspec = DataSpec(dataloader=train_dataloader,
- device_transforms=NormalizationFn(mean=IMAGENET_CHANNEL_MEAN,
- std=IMAGENET_CHANNEL_STD,
- ignore_background=True))
- logging.info('Built train dataloader\n')
-
- # Validation dataset code
- logging.info('Building evaluation dataloader')
-
- # Validation image and target transformations
- image_transforms = transforms.Resize(size=(args.eval_resize_size, args.eval_resize_size),
- interpolation=InterpolationMode.BILINEAR)
- target_transforms = transforms.Resize(size=(args.eval_resize_size, args.eval_resize_size),
- interpolation=InterpolationMode.NEAREST)
-
- # Create ADE20k validation dataset
- val_dataset = ADE20k(datadir=args.data_dir,
- split='validation',
- both_transforms=None,
- image_transforms=image_transforms,
- target_transforms=target_transforms)
-
- #Create ADE20k validation dataloader
-
- val_sampler = None
- if dist.get_world_size():
- # Nifty function to instantiate a PyTorch DistributedSampler based on your hardware
- val_sampler = dist.get_sampler(val_dataset, drop_last=False, shuffle=False)
-
- val_dataloader = DataLoader(
- val_dataset,
- batch_size=args.eval_batch_size,
- num_workers=8,
- pin_memory=True,
- drop_last=False,
- sampler=val_sampler,
- collate_fn=pil_image_collate,
- persistent_workers=True,
- )
-
- # DataSpec enables image normalization to be performed on-GPU, marginally relieving dataloader bottleneck
- val_dataspec = DataSpec(dataloader=val_dataloader,
- device_transforms=NormalizationFn(mean=IMAGENET_CHANNEL_MEAN,
- std=IMAGENET_CHANNEL_STD,
- ignore_background=True))
- logging.info('Built validation dataset\n')
-
- logging.info('Building Composer DeepLabv3+ model')
-
- # Create a DeepLabv3+ model
- model = deeplabv3(
- num_classes=150,
- backbone_arch=args.backbone_arch,
- backbone_weights='IMAGENET1K_V2',
- sync_bn=args.sync_bn,
- use_plus=True,
- )
-
- # Initialize the classifier head only since the backbone uses pre-trained weights
- def weight_init(module: torch.nn.Module):
- if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
- torch.nn.init.kaiming_normal_(module.weight)
- if isinstance(module, torch.nn.BatchNorm2d):
- torch.nn.init.ones_(module.weight)
- torch.nn.init.zeros_(module.bias)
-
- model.classifier.apply(weight_init) # type: ignore Does not recognize classifier as a torch.nn.Module
-
- # Loss function to use during training
- # This ignores index -1 since the NormalizationFn transformation sets the background class to -1
- dice_loss_fn = DiceLoss(softmax=True, batch=True, ignore_absent_classes=True)
-
- def combo_loss(output, target):
- loss = {}
- loss['cross_entropy'] = soft_cross_entropy(output, target, ignore_index=-1)
- loss['dice'] = dice_loss_fn(output, target)
- loss['total'] = args.cross_entropy_weight * loss['cross_entropy'] + args.dice_weight * loss['dice']
- return loss
-
- # Training and Validation metrics to log throughout training
- train_metrics = MetricCollection([CrossEntropy(ignore_index=-1), MIoU(num_classes=150, ignore_index=-1)])
- val_metrics = MetricCollection([CrossEntropy(ignore_index=-1), MIoU(num_classes=150, ignore_index=-1)])
-
- # Create a ComposerClassifier using the model, loss function, and metrics
- composer_model = ComposerClassifier(module=model,
- train_metrics=train_metrics,
- val_metrics=val_metrics,
- loss_fn=combo_loss)
-
- logging.info('Built Composer DeepLabv3+ model\n')
-
- logging.info('Building optimizer and learning rate scheduler')
- # Optimizer
- optimizer = DecoupledSGDW(composer_model.parameters(),
- lr=args.learning_rate,
- momentum=args.momentum,
- weight_decay=args.weight_decay)
-
- # Only use a LR schedule if no recipe is specified or if the hot recipe was specified
- lr_scheduler = None
- if args.recipe_name is None or args.recipe_name == 'hot':
- lr_scheduler = CosineAnnealingScheduler()
-
- logging.info('Built optimizer and learning rate scheduler')
-
- logging.info('Building callbacks: SpeedMonitor, LRMonitor, and CheckpointSaver')
- speed_monitor = SpeedMonitor(window_size=50) # Measures throughput as samples/sec and tracks total training time
- lr_monitor = LRMonitor() # Logs the learning rate
-
- # Callback for checkpointing
- checkpoint_saver = CheckpointSaver(folder=args.save_checkpoint_dir, save_interval=args.checkpoint_interval)
- logging.info('Built callbacks: SpeedMonitor, LRMonitor, and CheckpointSaver\n')
-
- # Recipes for training DeepLabv3+ on ImageNet in order of increasing training time and accuracy
- # To learn about individual methods, check out "Methods Overview" in our documentation: https://docs.mosaicml.com/
- logging.info('Building algorithm recipes')
- if args.recipe_name == 'mild':
- algorithms = [
- ChannelsLast(),
- EMA(half_life='1000ba', update_interval='10ba'),
- ]
- elif args.recipe_name == 'medium':
- algorithms = [
- ChannelsLast(),
- EMA(half_life='1000ba', update_interval='10ba'),
- SAM(rho=0.3, interval=2),
- MixUp(alpha=0.2),
- ]
- elif args.recipe_name == 'hot':
- algorithms = [
- ChannelsLast(),
- EMA(half_life='2000ba', update_interval='1ba'),
- SAM(rho=0.3, interval=1),
- MixUp(alpha=0.5),
- ]
- else:
- algorithms = None
- logging.info('Built algorithm recipes\n')
-
- # Weight and Biases logger if specified in commandline
- logger = None
- if args.wandb_logger:
- logging.info('Building Weights and Biases logger')
- if args.wandb_entity is None:
- raise ValueError('Please specify --wandb_entity argument')
- if args.wandb_project is None:
- raise ValueError('Please specify --wandb_project argument')
- logger = WandBLogger(entity=args.wandb_entity, project=args.wandb_project)
- logging.info('Built Weights and Biases logger')
-
- callbacks = [speed_monitor, lr_monitor, checkpoint_saver]
- if args.image_viz:
- callbacks.append(ImageVisualizer(mode='segmentation'))
- # Create the Trainer!
- logging.info('Building Trainer')
- device = 'gpu' if torch.cuda.is_available() else 'cpu'
- precision = 'amp' if device == 'gpu' else 'fp32' # Mixed precision for fast training when using a GPU
- device_train_microbatch_size = 'auto' if device == 'gpu' else args.device_train_microbatch_size # If on GPU, use 'auto' gradient accumulation
- trainer = Trainer(run_name=args.run_name,
- model=composer_model,
- train_dataloader=train_dataspec,
- eval_dataloader=val_dataspec,
- eval_interval='1ep',
- optimizers=optimizer,
- schedulers=lr_scheduler,
- algorithms=algorithms,
- loggers=logger,
- max_duration=args.max_duration,
- callbacks=callbacks,
- load_path=args.load_checkpoint_path,
- device=device,
- precision=precision,
- device_train_microbatch_size=device_train_microbatch_size,
- seed=args.seed)
- logging.info('Built Trainer\n')
-
- # Start training!
- logging.info('Train!')
- trainer.fit()
-
-
-if __name__ == '__main__':
- _main()
diff --git a/pyproject.toml b/pyproject.toml
index 342c9b3d7e..1583440640 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -82,18 +82,15 @@ reportUnusedCoroutine = "error"
# Pytest
[tool.pytest.ini_options]
# By default, do not run gpu, vision, docs, notebook, or daily tests
-addopts = "--codeblocks --strict-markers -m 'not gpu and not vision and not doctest and not daily and not remote'"
+addopts = "--codeblocks --strict-markers -m 'not gpu and not doctest and not daily and not remote'"
markers = [
- # !!!!!!!!!!!IMPORTANT!!!!!!!!!: when updating the markers, also make sure to update meta.yaml
# Tests that require a world_size of two should be annotated with `@pytest.mark.world_size(2)`.
# If not specified, the test will be assumed to have a world-size of one, which is
# equivalent to `@pytest.mark.world_size(1)`
"world_size(val)",
# Tests that require a gpu should be annotated with `@pytest.mark.gpu`
"gpu",
- # Whether the test should run in a container based on the vision dockerimage, which contains ffcv and opencv
- "vision",
# Tests which are run as part of the documentation build
"doctest",
# Should be run during daily regression
@@ -151,6 +148,12 @@ filterwarnings = [
'''ignore:torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead:UserWarning''',
# Ignore torch sharded tensor deprecated warnings
'''ignore:Please use DTensor instead and we are deprecating ShardedTensor.:UserWarning''',
+ # Ignore torch pytree deprecated warnings
+ '''ignore:torch.utils._pytree._register_pytree_node is deprecated.*:UserWarning''',
+ # Ignore autograd kernel warning inside DeepSpeed
+ '''ignore:.*an autograd kernel was not registered to the Autograd key.*:UserWarning''',
+ # Ignore save_state_dict / load_state_dict deprecation warnings
+ '''ignore:'.*_state_dict' is deprecated and will be removed in future versions.*:UserWarning'''
]
# Coverage
diff --git a/scripts/ffcv/create_ffcv_datasets.py b/scripts/ffcv/create_ffcv_datasets.py
deleted file mode 100644
index 190974c762..0000000000
--- a/scripts/ffcv/create_ffcv_datasets.py
+++ /dev/null
@@ -1,148 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""Helper utilities to create FFCV datasets."""
-
-import logging
-import os
-import sys
-import textwrap
-from argparse import ArgumentParser
-from io import BytesIO
-from typing import Tuple
-
-import numpy as np
-import torch
-from PIL import Image
-from torch.utils.data import Subset
-from torchvision import transforms
-from torchvision.datasets import CIFAR10, ImageFolder
-from tqdm import tqdm
-
-from composer.datasets.ffcv_utils import write_ffcv_dataset
-
-log = logging.getLogger(__name__)
-
-
-def _get_parser():
- parser = ArgumentParser(description='Utility for converting datasets to ffcv format.')
-
- parser.add_argument('--dataset',
- type=str,
- default='cifar10',
- choices=['cifar10', 'imagenet1k'],
- help=textwrap.dedent("""\
- Dataset to use. Default: cifar10"""))
- parser.add_argument('--remote',
- type=str,
- help=textwrap.dedent("""\
- Remote directory (S3 or local filesystem) where dataset is stored., Example: s3://my-s3-bucket-name"""
- ))
- parser.add_argument('--local',
- type=str,
- default=None,
- help=textwrap.dedent("""\
- Local filesystem directory where dataset is cached during operation. Default: None"""))
- parser.add_argument('--split',
- type=str,
- default='train',
- choices=['train', 'val'],
- help=textwrap.dedent("""\
- Split to use. Default: train"""))
-
- parser.add_argument('--datadir',
- type=str,
- default=None,
- help=textwrap.dedent("""\
- Location of the dataset. Default: None"""))
-
- parser.add_argument('--download',
- type=bool,
- default=False,
- help=textwrap.dedent("""\
- Download the dataset if possible. Default: False"""))
-
- parser.add_argument('--write_path',
- type=str,
- default=None,
- help=textwrap.dedent("""\
- File path to use for writing the dataset. Default: /tmp/_.ffcv"""))
-
- parser.add_argument('--write_mode',
- type=str,
- default='proportion',
- choices=['raw', 'jpg', 'smart', 'proportion'],
- help=textwrap.dedent("""\
- Write mode to use. raw is uint8 values, jpg is jpeg compressed images, smart is
- compressing based on image size and proportion is according to the given
- compress_probability. Default: proportion"""))
-
- parser.add_argument('--max_resolution', type=int, default=500, help='Max resoultion for images.')
-
- parser.add_argument('--num_workers', type=int, default=64, help='Number of workers to use.')
-
- parser.add_argument('--chunk_size', type=int, default=100, help='Chunk size to use.')
-
- parser.add_argument('--jpeg_quality', type=int, default=90, help='Quality of jpeg.')
-
- parser.add_argument('--subset', type=int, default=-1, help='Only use a subset of dataset.')
-
- parser.add_argument('--compress_probability',
- type=float,
- required=False,
- default=0.50,
- help='Compress the given fraction of images to jpeg while writing the ffcv dataset.')
- return parser
-
-
-def _parse_args():
- parser = _get_parser()
-
- args = parser.parse_args()
-
- if args.datadir is not None:
- log.info(f'Will read from local directory: {args.datadir}.')
- else:
- if args.local is None:
- args.local = f'/tmp/mds-cache/mds-{args.dataset}/'
-
- if args.remote.startswith('s3://'):
- log.info(f'Will read from remote: {args.remote}.')
- else:
- log.info(f'Will read from local: {args.remote}.')
-
- if args.write_path is None:
- args.write_path = f'/tmp/{args.dataset}_{args.split}.ffcv'
-
- if os.path.exists(args.write_path):
- log.error(f'Destination already exists: {args.write_path}')
- sys.exit(-1)
-
- return args
-
-
-def _main():
- args = _parse_args()
-
- if args.dataset == 'cifar10':
- dataset = CIFAR10(root=args.datadir, train=(args.split == 'train'), download=args.download)
- elif args.dataset == 'imagenet1k':
- dataset = ImageFolder(os.path.join(args.datadir, args.split))
- else:
- raise ValueError(f'Unsupported dataset: {args.dataset}. Checkout the list of supported datasets with -h')
-
- if args.subset > 0:
- dataset = Subset(dataset, range(args.subset))
-
- write_ffcv_dataset(dataset=dataset,
- write_path=args.write_path,
- max_resolution=args.max_resolution,
- num_workers=args.num_workers,
- write_mode=args.write_mode,
- compress_probability=args.compress_probability,
- jpeg_quality=args.jpeg_quality,
- chunk_size=args.chunk_size)
-
-
-if __name__ == '__main__':
- sys.exit(_main())
diff --git a/setup.py b/setup.py
index 7322bdc49e..6600f716a7 100644
--- a/setup.py
+++ b/setup.py
@@ -76,10 +76,10 @@ def package_files(prefix: str, directory: str, extension: str):
install_requires = [
'pyyaml>=6.0,<7',
'tqdm>=4.62.3,<5',
- 'torchmetrics>=0.10.0,<1.1',
+ 'torchmetrics>=0.10.0,<1.3.1',
'torch_optimizer>=0.3.0,<0.4',
- 'torchvision>=0.13.1,<0.19',
- 'torch>=1.13.1,<2.2.1',
+ 'torchvision>=0.13.1,<0.20', # TODO: Tighten before release
+ 'torch>=2.0.1,<2.3.1', # TODO: Tighten before release
'requests>=2.26.0,<3',
'numpy>=1.21.5,<1.27.0',
'psutil>=5.8.0,<6',
@@ -88,7 +88,7 @@ def package_files(prefix: str, directory: str, extension: str):
'py-cpuinfo>=8.0.0,<10',
'packaging>=21.3.0,<23',
'importlib-metadata>=5.0.0,<7',
- 'mosaicml-cli>=0.5.25,<0.6',
+ 'mosaicml-cli>=0.5.25,<0.7',
]
extra_deps = {}
@@ -100,14 +100,13 @@ def package_files(prefix: str, directory: str, extension: str):
# Should manually update dependency versions occassionally.
'custom_inherit==2.4.1',
'junitparser==3.1.1',
- 'coverage[toml]==7.3.4',
+ 'coverage[toml]==7.4.1',
'fasteners==0.18', # object store tests require fasteners
'pytest==7.4.4',
- 'toml==0.10.2',
'ipython==8.11.0',
- 'ipykernel==6.28.0',
+ 'ipykernel==6.29.2',
'jupyter==1.0.0',
- 'yamllint==1.33.0',
+ 'yamllint==1.34.0',
'recommonmark==0.7.1',
'sphinx==4.4.0',
'pre-commit>=3.4.0,<4',
@@ -117,6 +116,11 @@ def package_files(prefix: str, directory: str, extension: str):
'sphinx_markdown_tables==0.0.17',
'sphinx-argparse==0.4.0',
'sphinxcontrib.katex==0.9.6',
+ 'sphinxcontrib-applehelp==1.0.0',
+ 'sphinxcontrib-devhelp==1.0.0',
+ 'sphinxcontrib-htmlhelp==2.0.0',
+ 'sphinxcontrib-serializinghtml==1.1.5',
+ 'sphinxcontrib-qthelp==1.0.0',
'sphinxext.opengraph==0.9.1',
'sphinxemoji==0.2.0',
'furo==2022.9.29',
@@ -130,7 +134,7 @@ def package_files(prefix: str, directory: str, extension: str):
'nbsphinx==0.9.1',
'pandoc==2.3',
'pypandoc==1.12',
- 'GitPython==3.1.40',
+ 'GitPython==3.1.41',
'moto[s3]>=4.0.1,<5',
'mock-ssh-server==0.9.1',
'cryptography==41.0.5',
@@ -138,10 +142,6 @@ def package_files(prefix: str, directory: str, extension: str):
'setuptools<=59.5.0',
]
-extra_deps['health_checker'] = {
- 'pynvml>=11.5.0,<12',
-}
-
extra_deps['system_metrics_monitor'] = {
'pynvml>=11.5.0,<12',
}
@@ -163,21 +163,12 @@ def package_files(prefix: str, directory: str, extension: str):
'comet_ml>=3.31.12,<4.0.0',
]
-extra_deps['tensorboard'] = [
- 'tensorboard>=2.9.1,<3.0.0',
-]
-
-extra_deps['unet'] = [
- 'monai>=0.9.1,<1.4',
- 'scikit-learn>=1.0.1,<2',
+extra_deps['neptune'] = [
+ 'neptune>=1.6.2,<2.0.0',
]
-extra_deps['vit'] = [
- 'vit_pytorch==1.6.1',
-]
-
-extra_deps['timm'] = [
- 'timm>=0.5.4,<0.6',
+extra_deps['tensorboard'] = [
+ 'tensorboard>=2.9.1,<3.0.0',
]
extra_deps['coco'] = [
@@ -185,10 +176,14 @@ def package_files(prefix: str, directory: str, extension: str):
]
extra_deps['nlp'] = [
- 'transformers>=4.11,<4.37,!=4.34.0',
+ 'transformers>=4.11,<4.38,!=4.34.0',
'datasets>=2.4,<3',
]
+extra_deps['peft'] = [
+ 'peft>=0.7.0,<0.8',
+]
+
extra_deps['sentencepiece'] = [
'protobuf<3.21',
'sentencepiece==0.1.99',
@@ -229,7 +224,7 @@ def package_files(prefix: str, directory: str, extension: str):
extra_deps['pandas'] = ['pandas>=2.0.0,<3.0']
-extra_deps['databricks'] = ['databricks-sdk>=0.15.0,<1.0']
+extra_deps['databricks'] = ['databricks-sdk==0.18.0']
extra_deps['all'] = {dep for deps in extra_deps.values() for dep in deps}
@@ -258,9 +253,9 @@ def package_files(prefix: str, directory: str, extension: str):
packages=setuptools.find_packages(exclude=['docker*', 'examples*', 'scripts*', 'tests*']),
classifiers=[
'Programming Language :: Python :: 3',
- 'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
+ 'Programming Language :: Python :: 3.11',
],
install_requires=install_requires,
entry_points={
@@ -272,7 +267,7 @@ def package_files(prefix: str, directory: str, extension: str):
},
extras_require=extra_deps,
dependency_links=['https://developer.download.nvidia.com/compute/redist'],
- python_requires='>=3.8',
+ python_requires='>=3.9',
ext_package='composer',
cmdclass={'develop': develop})
diff --git a/tests/algorithms/algorithm_settings.py b/tests/algorithms/algorithm_settings.py
index 940ca040f2..91ecf2dac2 100644
--- a/tests/algorithms/algorithm_settings.py
+++ b/tests/algorithms/algorithm_settings.py
@@ -21,12 +21,11 @@
LabelSmoothing, LayerFreezing, LowPrecisionGroupNorm, LowPrecisionLayerNorm, MixUp,
NoOpModel, ProgressiveResizing, RandAugment, SelectiveBackprop, SeqLengthWarmup,
SqueezeExcite, StochasticDepth, WeightStandardization)
-from composer.models import composer_resnet
from composer.models.base import ComposerModel
from composer.utils import dist
from tests.common import get_module_subclasses
from tests.common.datasets import RandomImageDataset, SimpleDataset, dummy_bert_lm_dataloader, dummy_gpt_lm_dataloader
-from tests.common.models import (SimpleConvModel, SimpleModelWithDropout, configure_tiny_bert_hf_model,
+from tests.common.models import (SimpleConvModel, SimpleModelWithDropout, composer_resnet, configure_tiny_bert_hf_model,
configure_tiny_gpt2_hf_model)
simple_bert_settings = {
diff --git a/tests/algorithms/test_algorithm_resumption.py b/tests/algorithms/test_algorithm_resumption.py
index d1fb4e2c40..9f243caeae 100644
--- a/tests/algorithms/test_algorithm_resumption.py
+++ b/tests/algorithms/test_algorithm_resumption.py
@@ -57,7 +57,7 @@ def test_algorithm_resumption(
'save_filename': 'ep{epoch}-rank{rank}',
'save_interval': '1ep',
'train_subset_num_batches': 2,
- 'precision': 'amp_fp16',
+ 'precision': 'amp_bf16',
}
train_dataloader = get_alg_dataloader(alg_cls) if world_size == 1 else get_alg_dataloader(alg_cls, multigpu=True)
# train model once, saving checkpoints every epoch
@@ -117,6 +117,7 @@ def test_algorithm_resumption(
def _assert_checkpoints_equal(file1, file2):
+ # TODO: consider merging with _assert_checkpoints_equivalent
checkpoint1 = torch.load(file1)
checkpoint2 = torch.load(file2)
@@ -136,6 +137,10 @@ def _assert_checkpoints_equal(file1, file2):
del checkpoint1['state']['run_name']
del checkpoint2['state']['run_name']
+ # Remove all saved checkpoints to timestamp (accumulates between runs)
+ del checkpoint1['state']['callbacks']['CheckpointSaver']['all_saved_checkpoints_to_timestamp']
+ del checkpoint2['state']['callbacks']['CheckpointSaver']['all_saved_checkpoints_to_timestamp']
+
# Remove algorithm representations which are memory addresses
for i, algo_info in enumerate(checkpoint1['state']['algorithms']):
if '0x' in algo_info[1]['repr']:
diff --git a/tests/algorithms/test_alibi.py b/tests/algorithms/test_alibi.py
index 81617a5ade..c33bd58bff 100644
--- a/tests/algorithms/test_alibi.py
+++ b/tests/algorithms/test_alibi.py
@@ -93,7 +93,7 @@ def test_registry(caplog):
from composer.algorithms.alibi.attention_surgery_functions import policy_registry
@policy_registry.register(torch.nn.Linear)
- def zero_linear_weights( # pyright: reportUnusedFunction = none
+ def zero_linear_weights( # pyright: ignore[reportUnusedFunction]
module: torch.nn.Module, idx: int, max_sequence_length: int) -> torch.nn.Module:
assert isinstance(module, torch.nn.Linear)
old_weight = getattr(module, 'weight')
diff --git a/tests/algorithms/test_colout.py b/tests/algorithms/test_colout.py
index 9e71d2554c..007bd43fd1 100644
--- a/tests/algorithms/test_colout.py
+++ b/tests/algorithms/test_colout.py
@@ -1,8 +1,8 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
import functools
-from typing import Tuple
import numpy as np
import pytest
@@ -28,7 +28,8 @@ def verify_shape_image(orig: Image.Image, new: Image.Image, p_row: float, p_col:
assert (H_n, W_n) == (H_t, W_t), f'Image shape mismatch: {(H_n, W_n)} != {(H_t, W_t)}'
-def verify_shape_image_pair(orig_sample: Tuple[Image.Image, Image.Image], new_sample: Tuple[Image.Image, Image.Image],
+def verify_shape_image_pair(orig_sample: tuple[Image.Image, Image.Image],
+ new_sample: tuple[torch.Tensor, torch.Tensor] | tuple[Image.Image, Image.Image],
p_row: float, p_col: float):
"""Verify the shape of a pair of transformed PIL images."""
H_o, W_o = orig_sample[0].height, orig_sample[0].width
@@ -50,8 +51,8 @@ def verify_shape_tensor(orig: torch.Tensor, new: torch.Tensor, p_row: float, p_c
assert new.shape == (C, H_t, W_t), f'Image tensor shape mismatch: {new.shape} != {(C, H_t, W_t)}'
-def verify_shape_tensor_pair(orig_sample: Tuple[torch.Tensor, torch.Tensor],
- new_sample: Tuple[torch.Tensor, torch.Tensor], p_row: float, p_col: float) -> None:
+def verify_shape_tensor_pair(orig_sample: tuple[torch.Tensor, torch.Tensor],
+ new_sample: tuple[torch.Tensor, torch.Tensor], p_row: float, p_col: float) -> None:
"""Verify the shape of a transformed image tensor."""
C, H_o, W_o = orig_sample[0].shape
@@ -72,8 +73,8 @@ def verify_shape_batch(orig: torch.Tensor, new: torch.Tensor, p_row: float, p_co
assert new.shape == (N, C, H_t, W_t), f'Image batch shape mismatch: {new.shape} != {(N, C, H_t, W_t)}'
-def verify_shape_batch_pair(orig_sample: Tuple[torch.Tensor, torch.Tensor],
- new_sample: Tuple[torch.Tensor, torch.Tensor], p_row: float, p_col: float) -> None:
+def verify_shape_batch_pair(orig_sample: tuple[torch.Tensor, torch.Tensor],
+ new_sample: tuple[torch.Tensor, torch.Tensor], p_row: float, p_col: float) -> None:
"""Verify the shape of a transformed batch of images."""
N, C, H_o, W_o = orig_sample[0].shape
@@ -163,7 +164,7 @@ def test_image_pair_drop_size(self, fake_image: Image.Image, p_row: float, p_col
transform = ColOutTransform(p_row, p_col)
orig_sample = (fake_image, fake_image)
new_sample = transform(orig_sample)
- assert isinstance(new_sample, Tuple)
+ assert isinstance(new_sample, tuple)
verify_shape_image_pair(orig_sample, new_sample, p_row, p_col)
@pytest.mark.parametrize('W', [48])
@@ -228,7 +229,7 @@ def test_batch_pair_drop_size(self, fake_image_batch: torch.Tensor, p_row: float
colout = functools.partial(colout_batch, p_row=p_row, p_col=p_col)
sample = (fake_image_batch, fake_image_batch)
new_batch = colout(sample)
- assert isinstance(new_batch, Tuple) and isinstance(new_batch[0], torch.Tensor) and isinstance(
+ assert isinstance(new_batch, tuple) and isinstance(new_batch[0], torch.Tensor) and isinstance(
new_batch[1], torch.Tensor)
verify_shape_batch_pair(sample, new_batch, p_row, p_col)
diff --git a/tests/algorithms/test_gradient_clipping.py b/tests/algorithms/test_gradient_clipping.py
index fe06fa188b..57e71c35f5 100644
--- a/tests/algorithms/test_gradient_clipping.py
+++ b/tests/algorithms/test_gradient_clipping.py
@@ -5,7 +5,6 @@
import pytest
import torch
-from packaging import version
from torch import nn
import composer.algorithms.gradient_clipping.gradient_clipping as gc_module
@@ -13,7 +12,6 @@
from composer.algorithms.gradient_clipping.gradient_clipping import _apply_agc, _get_clipped_gradient_coeff
from composer.core import Engine, State
from composer.core.event import Event
-from composer.utils.misc import using_torch_2
from tests.common import world_size
from tests.common.datasets import dummy_tiny_bert_classification_batch, dummy_transformer_classifier_batch
from tests.common.models import SimpleTransformerClassifier, configure_tiny_bert_config
@@ -29,9 +27,9 @@ def simple_model_with_grads():
# Force wrap every module in FSDP, to allow for testing FSDP
# gradient clipping properly.
for module in model:
- module._fsdp_wrap = True
+ module._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues]
- model._fsdp_wrap = True
+ model._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues]
o = model(x)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(o, y)
@@ -64,7 +62,7 @@ def forward(self, x):
# Force wrap every module in FSDP, to allow for testing FSDP
# gradient clipping properly.
for layer in model.modules():
- layer._fsdp_wrap = True
+ layer._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues]
o = model(x)
loss_fn = nn.CrossEntropyLoss()
@@ -79,7 +77,7 @@ def simple_transformer_model_with_grads():
# Force wrap every module in FSDP, to allow for testing FSDP
# gradient clipping properly.
for layer in model.modules():
- layer._fsdp_wrap = True
+ layer._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues]
x = dummy_transformer_classifier_batch(num_classes=3)
o = model(x)
@@ -104,7 +102,7 @@ def hf_model_with_grads():
# Force wrap every module in FSDP, to allow for testing FSDP
# gradient clipping properly.
for layer in model.modules():
- layer._fsdp_wrap = True
+ layer._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues]
x = dummy_tiny_bert_classification_batch(num_classes=3)
o = model(x).logits
@@ -193,27 +191,17 @@ def test_gradient_clipping_algorithm_with_deepspeed_enabled(
apply_gc_fn.assert_not_called()
-if not using_torch_2():
+def _auto_wrap_policy(module: torch.nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
+ if recurse:
+ return True
- def _auto_wrap_policy(module: torch.nn.Module, recurse: bool, unwrapped_params: int) -> bool: # type: ignore
- if recurse:
- return True
- if hasattr(module, '_fsdp_wrap'):
- return bool(module._fsdp_wrap)
+ # With Torch 2.0, there is a bug that emits a nasty warning if you wrap a module with no parameters
+ if len(list(module.parameters())) == 0:
return False
-else:
- def _auto_wrap_policy(module: torch.nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
- if recurse:
- return True
-
- # With Torch 2.0, there is a bug that emits a nasty warning if you wrap a module with no parameters
- if len(list(module.parameters())) == 0:
- return False
-
- if hasattr(module, '_fsdp_wrap'):
- return bool(module._fsdp_wrap)
- return False
+ if hasattr(module, '_fsdp_wrap'):
+ return bool(module._fsdp_wrap)
+ return False
@pytest.mark.parametrize('model_with_grads', [
@@ -223,8 +211,6 @@ def _auto_wrap_policy(module: torch.nn.Module, recurse: bool, nonwrapped_numel:
hf_model_with_grads
])
@pytest.mark.parametrize('clipping_type', ['norm', 'value'])
-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
- reason='requires PyTorch 1.13 or higher')
@pytest.mark.gpu
@world_size(2)
def test_gradient_clipping_algorithm_with_fsdp_enabled_does_not_error(
@@ -241,13 +227,10 @@ def test_gradient_clipping_algorithm_with_fsdp_enabled_does_not_error(
clipping_threshold = 0.1191
state = dummy_state
- torch_2_kwargs = {}
- if using_torch_2():
- torch_2_kwargs['use_orig_params'] = True
state.model = FullyShardedDataParallel(model,
auto_wrap_policy=_auto_wrap_policy,
device_id=torch.cuda.current_device(),
- **torch_2_kwargs)
+ use_orig_params=True)
state.algorithms = [GradientClipping(clipping_type=clipping_type, clipping_threshold=clipping_threshold)]
logger = Mock()
diff --git a/tests/algorithms/test_required_on_load.py b/tests/algorithms/test_required_on_load.py
index 3844a57084..defaaa4389 100644
--- a/tests/algorithms/test_required_on_load.py
+++ b/tests/algorithms/test_required_on_load.py
@@ -9,12 +9,13 @@
import pytest
import torch
+from packaging import version
from composer import Trainer, algorithms
from composer.callbacks import CheckpointSaver
from composer.core import Algorithm, Event, Time, TimeUnit # type: ignore imports used in `eval(representation)`
-from composer.models import ComposerClassifier, ComposerModel, composer_resnet
-from tests.common import ConvModel, SimpleConvModel
+from composer.models import ComposerClassifier, ComposerModel
+from tests.common import ConvModel, SimpleConvModel, composer_resnet
def initialize_algorithm(algo_cls: Type):
@@ -163,14 +164,20 @@ def test_autoload(algo_name: str, load_weights_only: bool, already_added: bool,
context = pytest.warns(UserWarning, match='Automatically adding required_on_load algorithm*')
# Excluding some algorithms leads to errors when loading
elif exclude:
- if algo_name in ['Factorize', 'SqueezeExcite']:
- context = pytest.raises(
- ValueError,
- match=
- "loaded state dict contains a parameter group that doesn't match the size of optimizer's group",
- )
- elif algo_name == 'Alibi':
- context = pytest.raises(RuntimeError)
+ if version.parse(torch.__version__) > version.parse('2.2.9'):
+ if algo_name in [
+ 'Alibi', 'BlurPool', 'Factorize', 'GatedLinearUnits', 'GhostBatchNorm', 'SqueezeExcite'
+ ]:
+ context = pytest.raises(KeyError) # Optimizer loading is strict
+ else:
+ if algo_name in ['Factorize', 'SqueezeExcite']:
+ context = pytest.raises(
+ ValueError,
+ match=
+ "loaded state dict contains a parameter group that doesn't match the size of optimizer's group",
+ )
+ elif algo_name == 'Alibi':
+ context = pytest.raises(RuntimeError)
with context:
trainer2 = Trainer(
diff --git a/tests/algorithms/test_stochastic_depth.py b/tests/algorithms/test_stochastic_depth.py
index 23c21bd816..2ec267756a 100644
--- a/tests/algorithms/test_stochastic_depth.py
+++ b/tests/algorithms/test_stochastic_depth.py
@@ -14,8 +14,8 @@
from composer.algorithms.stochastic_depth.stochastic_layers import make_resnet_bottleneck_stochastic
from composer.core import Event, State
from composer.core.time import TimeUnit
-from composer.models import composer_resnet
from composer.utils import module_surgery
+from tests.common import composer_resnet
@pytest.fixture()
diff --git a/tests/callbacks/callback_settings.py b/tests/callbacks/callback_settings.py
index ef9fe12187..f6065c1863 100644
--- a/tests/callbacks/callback_settings.py
+++ b/tests/callbacks/callback_settings.py
@@ -3,6 +3,7 @@
import os
from typing import Any, Dict, List, Tuple, Type
+from unittest.mock import MagicMock
import pytest
from torch.utils.data import DataLoader
@@ -11,11 +12,11 @@
import composer.loggers
import composer.profiler
from composer import Callback
-from composer.callbacks import (EarlyStopper, ExportForInferenceCallback, FreeOutputs, Generate, HealthChecker,
- ImageVisualizer, MemoryMonitor, MLPerfCallback, SpeedMonitor, SystemMetricsMonitor,
- ThresholdStopper)
-from composer.loggers import (CometMLLogger, ConsoleLogger, LoggerDestination, MLFlowLogger, ProgressBarLogger,
- RemoteUploaderDownloader, TensorboardLogger, WandBLogger)
+from composer.callbacks import (EarlyStopper, ExportForInferenceCallback, FreeOutputs, Generate, ImageVisualizer,
+ MemoryMonitor, MemorySnapshot, MLPerfCallback, OOMObserver, SpeedMonitor,
+ SystemMetricsMonitor, ThresholdStopper)
+from composer.loggers import (CometMLLogger, ConsoleLogger, LoggerDestination, MLFlowLogger, NeptuneLogger,
+ ProgressBarLogger, RemoteUploaderDownloader, TensorboardLogger, WandBLogger)
from composer.models.base import ComposerModel
from composer.utils import dist
from composer.utils.device import get_device
@@ -76,6 +77,13 @@
except ImportError:
_PYNMVL_INSTALLED = False
+try:
+ import neptune
+ _NEPTUNE_INSTALLED = True
+ del neptune # unused
+except ImportError:
+ _NEPTUNE_INSTALLED = False
+
_callback_kwargs: Dict[Type[Callback], Dict[str, Any],] = {
Generate: {
'prompts': ['a', 'b', 'c'],
@@ -115,6 +123,13 @@
SpeedMonitor: {
'window_size': 1,
},
+ NeptuneLogger: {
+ 'mode': 'debug',
+ },
+ composer.profiler.Profiler: {
+ 'trace_handlers': [MagicMock()],
+ 'schedule': composer.profiler.cyclic_schedule(),
+ }
}
_callback_marks: Dict[Type[Callback], List[pytest.MarkDecorator],] = {
@@ -128,6 +143,14 @@
pytest.mark.filterwarnings(
r'ignore:The memory monitor only works on CUDA devices, but the model is on cpu:UserWarning')
],
+ MemorySnapshot: [
+ pytest.mark.filterwarnings(
+ r'ignore:The memory snapshot only works on CUDA devices, but the model is on cpu:UserWarning')
+ ],
+ OOMObserver: [
+ pytest.mark.filterwarnings(
+ r'ignore:The oom observer only works on CUDA devices, but the model is on cpu:UserWarning')
+ ],
MLPerfCallback: [pytest.mark.skipif(not _MLPERF_INSTALLED, reason='MLPerf is optional')],
WandBLogger: [
pytest.mark.filterwarnings(r'ignore:unclosed file:ResourceWarning'),
@@ -145,7 +168,7 @@
ImageVisualizer: [pytest.mark.skipif(not _WANDB_INSTALLED, reason='Wandb is optional')],
MLFlowLogger: [pytest.mark.skipif(not _MLFLOW_INSTALLED, reason='mlflow is optional'),],
SystemMetricsMonitor: [pytest.mark.skipif(not _PYNMVL_INSTALLED, reason='pynmvl is optional'),],
- HealthChecker: [pytest.mark.filterwarnings('ignore:.*HealthChecker is deprecated.*')],
+ NeptuneLogger: [pytest.mark.skipif(not _NEPTUNE_INSTALLED, reason='neptune is optional'),],
}
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index 695be08c55..f0ddbe43cc 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -43,12 +43,14 @@ class TestCallbacks:
def setup_class(cls):
pytest.importorskip('wandb', reason='WandB is optional.')
+ @pytest.mark.filterwarnings('ignore::UserWarning')
def test_callback_is_constructable(self, cb_cls: Type[Callback]):
cb_kwargs = get_cb_kwargs(cb_cls)
cb = cb_cls(**cb_kwargs)
assert isinstance(cb_cls, type)
assert isinstance(cb, cb_cls)
+ @pytest.mark.filterwarnings('ignore::UserWarning')
def test_multiple_fit_start_and_end(self, cb_cls: Type[Callback], dummy_state: State):
"""Test that callbacks do not crash when Event.FIT_START and Event.FIT_END is called multiple times."""
cb_kwargs = get_cb_kwargs(cb_cls)
@@ -69,6 +71,7 @@ def test_multiple_fit_start_and_end(self, cb_cls: Type[Callback], dummy_state: S
engine.run_event(Event.FIT_START)
engine.run_event(Event.FIT_END)
+ @pytest.mark.filterwarnings('ignore::UserWarning')
def test_idempotent_close(self, cb_cls: Type[Callback], dummy_state: State):
"""Test that callbacks do not crash when .close() and .post_close() are called multiple times."""
cb_kwargs = get_cb_kwargs(cb_cls)
@@ -85,6 +88,7 @@ def test_idempotent_close(self, cb_cls: Type[Callback], dummy_state: State):
engine.close()
engine.close()
+ @pytest.mark.filterwarnings('ignore::UserWarning')
def test_multiple_init_and_close(self, cb_cls: Type[Callback], dummy_state: State):
"""Test that callbacks do not crash when INIT/.close()/.post_close() are called multiple times in that order."""
cb_kwargs = get_cb_kwargs(cb_cls)
@@ -136,6 +140,7 @@ def _get_trainer(self, cb: Callback, device_train_microbatch_size: int):
torch_prof_memory_filename=None),
)
+ @pytest.mark.filterwarnings('ignore::UserWarning')
def test_trains(self, cb_cls: Type[Callback], device_train_microbatch_size: int, _remote: bool):
del _remote # unused. `_remote` must be passed through to parameterize the test markers.
cb_kwargs = get_cb_kwargs(cb_cls)
@@ -143,6 +148,7 @@ def test_trains(self, cb_cls: Type[Callback], device_train_microbatch_size: int,
trainer = self._get_trainer(cb, device_train_microbatch_size)
trainer.fit()
+ @pytest.mark.filterwarnings('ignore::UserWarning')
def test_trains_multiple_calls(self, cb_cls: Type[Callback], device_train_microbatch_size: int, _remote: bool):
"""
Tests that training with multiple fits complete.
diff --git a/tests/callbacks/test_checkpoint_saver.py b/tests/callbacks/test_checkpoint_saver.py
new file mode 100644
index 0000000000..67654b9b17
--- /dev/null
+++ b/tests/callbacks/test_checkpoint_saver.py
@@ -0,0 +1,45 @@
+# Copyright 2022 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+from composer.callbacks import CheckpointSaver
+from composer.core import Timestamp
+
+
+def test_stateful_checkpoint_saver():
+ checkpoint_saver = CheckpointSaver()
+ assert not checkpoint_saver.all_saved_checkpoints_to_timestamp
+
+ # empty state dict
+ empty_state_dict = checkpoint_saver.state_dict()
+ assert 'all_saved_checkpoints_to_timestamp' in empty_state_dict
+ assert len(empty_state_dict['all_saved_checkpoints_to_timestamp']) == 0
+
+ # backwards compatibility; empty state dict should not raise
+ checkpoint_saver.load_state_dict({})
+ assert not checkpoint_saver.all_saved_checkpoints_to_timestamp
+
+ # add a checkpoint and confirm it can save and load
+ checkpoint_saver.all_saved_checkpoints_to_timestamp = {
+ 'foobar/example-checkpoint.pt': Timestamp(epoch=1, batch=2),
+ }
+ new_state_dict = checkpoint_saver.state_dict()
+ assert 'all_saved_checkpoints_to_timestamp' in new_state_dict
+ assert len(new_state_dict['all_saved_checkpoints_to_timestamp']) == 1
+ checkpoint, ts = new_state_dict['all_saved_checkpoints_to_timestamp'][0]
+ assert checkpoint == 'foobar/example-checkpoint.pt'
+ assert isinstance(ts, dict)
+ assert ts['epoch'] == 1
+ assert ts['batch'] == 2
+ assert ts['sample'] == 0
+
+ # load works again if we clear the dict
+ checkpoint_saver.all_saved_checkpoints_to_timestamp = {}
+ checkpoint_saver.load_state_dict(new_state_dict)
+ assert checkpoint_saver.all_saved_checkpoints_to_timestamp
+ assert len(checkpoint_saver.all_saved_checkpoints_to_timestamp) == 1
+ assert 'foobar/example-checkpoint.pt' in checkpoint_saver.all_saved_checkpoints_to_timestamp
+ ts = checkpoint_saver.all_saved_checkpoints_to_timestamp['foobar/example-checkpoint.pt']
+ assert isinstance(ts, Timestamp)
+ assert ts.epoch == 1
+ assert ts.batch == 2
+ assert ts.sample == 0
diff --git a/tests/callbacks/test_generate.py b/tests/callbacks/test_generate.py
index a848071dff..c9247ce616 100644
--- a/tests/callbacks/test_generate.py
+++ b/tests/callbacks/test_generate.py
@@ -7,7 +7,6 @@
import pytest
import torch
-from packaging import version
from composer.callbacks import Generate
from composer.core import Event
@@ -24,8 +23,6 @@
class TestGenerate():
def _check_test_params(self, device, world_size, use_fsdp) -> None:
- if use_fsdp and version.parse(torch.__version__) < version.parse('1.13.0'):
- pytest.skip('FSDP requires torch >= 1.13.0')
if device == 'cpu' and use_fsdp:
pytest.skip('FSDP is not supported on CPU.')
if world_size == 1 and use_fsdp:
diff --git a/tests/callbacks/test_health_checker.py b/tests/callbacks/test_health_checker.py
deleted file mode 100644
index 5638699ca9..0000000000
--- a/tests/callbacks/test_health_checker.py
+++ /dev/null
@@ -1,112 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-import datetime
-from unittest.mock import MagicMock, patch
-
-import pytest
-
-from composer import Timestamp
-from composer.callbacks import HealthChecker
-from composer.callbacks.health_checker import GPUUtilization
-from composer.utils import dist
-from tests.common import world_size
-
-pynvml = pytest.importorskip('pynvml')
-pytest.importorskip('slack_sdk')
-
-
-class MockUtil:
-
- def __init__(self, util):
- self.gpu = util
-
-
-@pytest.mark.gpu
-@world_size(1, 2)
-@pytest.mark.filterwarnings('ignore:.*HealthChecker is deprecated.*')
-def test_gpu_utilization(world_size):
- assert HealthChecker._is_available()
-
- gpu_utilization_values = [
- MockUtil(100),
- MockUtil(10),
- MockUtil(100),
- MockUtil(100),
- MockUtil(100),
- MockUtil(100),
- ]
-
- with patch.multiple(pynvml,
- nvmlDeviceGetUtilizationRates=MagicMock(side_effect=gpu_utilization_values),
- nvmlDeviceGetCount=MagicMock(return_value=world_size)):
-
- gpu_utilization = GPUUtilization()
- gpu_utilization.sample()
- gpu_utilization.sample()
- gpu_utilization.sample()
- _, alert = gpu_utilization.check()
-
- should_alert = dist.get_local_rank() == 0 and world_size > 1
- assert alert == should_alert
-
-
-@pytest.mark.gpu
-@world_size(1, 2)
-@pytest.mark.filterwarnings('ignore:.*HealthChecker is deprecated.*')
-def test_health_checker(world_size):
-
- state = MagicMock()
- state.run_name = 'pytest-mock-run-kwei73'
- logger = MagicMock()
-
- health_checker = HealthChecker(
- sample_freq=1,
- window_size=3,
- wait=0,
- )
-
- gpu_utilization_values = [
- MockUtil(100),
- MockUtil(10),
- MockUtil(100),
- MockUtil(100),
- MockUtil(100),
- MockUtil(100),
- ]
-
- with patch.multiple(pynvml,
- nvmlDeviceGetUtilizationRates=MagicMock(side_effect=gpu_utilization_values),
- nvmlDeviceGetCount=MagicMock(return_value=world_size)):
-
- # collect data and checker
- for seconds in [1, 2, 3]:
- state.timestamp = Timestamp(total_wct=datetime.timedelta(seconds=seconds))
- health_checker.after_train_batch(state, logger)
-
- should_alert = dist.get_local_rank() == 0 and world_size > 1
- assert health_checker.metrics[0].alerted == should_alert
-
-
-@pytest.mark.filterwarnings('ignore:.*HealthChecker is deprecated.*')
-def test_health_checker_sampling():
- timestamp = Timestamp(total_wct=datetime.timedelta(seconds=0))
-
- health_checker = HealthChecker(
- sample_freq=1,
- window_size=5,
- wait=10,
- )
-
- config = [
- (5, False), # before wait
- (11, True),
- (11.5, False), # below sample frequency
- (12, True),
- (20, True),
- (11, False), # no time travel
- ]
-
- for seconds, is_sample in config:
- timestamp = Timestamp(total_wct=datetime.timedelta(seconds=seconds))
- assert health_checker._sample(timestamp) == is_sample
diff --git a/tests/callbacks/test_inference.py b/tests/callbacks/test_inference.py
index 960aec9a04..bef07c081c 100644
--- a/tests/callbacks/test_inference.py
+++ b/tests/callbacks/test_inference.py
@@ -13,9 +13,9 @@
from torch.utils.data import DataLoader
from composer.callbacks import ExportForInferenceCallback, export_for_inference
-from composer.models import composer_resnet
from composer.trainer import Trainer
from tests.common.datasets import RandomImageDataset
+from tests.common.models import composer_resnet
@pytest.mark.parametrize(
diff --git a/tests/callbacks/test_loggers_across_callbacks.py b/tests/callbacks/test_loggers_across_callbacks.py
index 92363e7aa5..1c58babf0b 100644
--- a/tests/callbacks/test_loggers_across_callbacks.py
+++ b/tests/callbacks/test_loggers_across_callbacks.py
@@ -15,6 +15,7 @@
@pytest.mark.parametrize('logger_cls', get_cbs_and_marks(loggers=True))
@pytest.mark.parametrize('callback_cls', get_cbs_and_marks(callbacks=True))
+@pytest.mark.filterwarnings('ignore::UserWarning')
def test_loggers_on_callbacks(logger_cls: Type[LoggerDestination], callback_cls: Type[Callback]):
if logger_cls in [ProgressBarLogger, ConsoleLogger, SlackLogger]:
pytest.skip()
diff --git a/tests/callbacks/test_memory_monitor.py b/tests/callbacks/test_memory_monitor.py
index f40a04eeb3..f2badc638c 100644
--- a/tests/callbacks/test_memory_monitor.py
+++ b/tests/callbacks/test_memory_monitor.py
@@ -7,13 +7,10 @@
from composer.callbacks import MemoryMonitor
from composer.loggers import InMemoryLogger
from composer.trainer import Trainer
-from tests.common import RandomClassificationDataset, SimpleModel, device
+from tests.common import RandomClassificationDataset, SimpleModel
-@device('cpu', 'gpu')
-def test_memory_monitor_warnings_on_cpu_models(device: str):
- # Error if the user sets device=cpu even when cuda is available
- del device # unused. always using cpu
+def test_memory_monitor_warnings_on_cpu_models():
with pytest.warns(UserWarning, match='The memory monitor only works on CUDA devices'):
Trainer(
model=SimpleModel(),
diff --git a/tests/callbacks/test_memory_snapshot.py b/tests/callbacks/test_memory_snapshot.py
new file mode 100644
index 0000000000..0bafbcb1c1
--- /dev/null
+++ b/tests/callbacks/test_memory_snapshot.py
@@ -0,0 +1,62 @@
+# Copyright 2022 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+import pathlib
+
+import pytest
+import torch
+from packaging import version
+from torch.utils.data import DataLoader
+
+from composer import State, Trainer
+from composer.callbacks import MemorySnapshot
+from composer.loggers import LoggerDestination
+from composer.trainer import Trainer
+from tests.common import RandomClassificationDataset, SimpleModel
+
+
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'),
+ reason='OOM Observer requires PyTorch 2.1 or higher')
+def test_memory_snapshot_warnings_on_cpu_models():
+ with pytest.warns(UserWarning):
+ Trainer(
+ model=SimpleModel(),
+ callbacks=MemorySnapshot(),
+ device='cpu',
+ train_dataloader=DataLoader(RandomClassificationDataset()),
+ max_duration='1ba',
+ )
+
+
+class FileUploaderTracker(LoggerDestination):
+
+ def __init__(self) -> None:
+ self.uploaded_files = []
+
+ def upload_file(self, state: State, remote_file_name: str, file_path: pathlib.Path, *, overwrite: bool):
+ del state, overwrite # unused
+ self.uploaded_files.append((remote_file_name, file_path))
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize('interval', ['1ba'])
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'),
+ reason='OOM Observer requires PyTorch 2.1 or higher')
+def test_memory_snapshot(interval: str):
+ # Construct the callbacks
+ skip_batches = 0
+ memory_snapshot = MemorySnapshot(skip_batches=skip_batches, interval=interval)
+ simple_model = SimpleModel()
+ file_tracker_destination = FileUploaderTracker()
+
+ # Construct the trainer and train
+ trainer = Trainer(
+ model=simple_model,
+ loggers=file_tracker_destination,
+ callbacks=memory_snapshot,
+ train_dataloader=DataLoader(RandomClassificationDataset()),
+ max_duration='2ba',
+ )
+ trainer.fit()
+ assert len(file_tracker_destination.uploaded_files) == 2
+ trainer.close()
diff --git a/tests/callbacks/test_oom_observer.py b/tests/callbacks/test_oom_observer.py
new file mode 100644
index 0000000000..60323b00c0
--- /dev/null
+++ b/tests/callbacks/test_oom_observer.py
@@ -0,0 +1,88 @@
+# Copyright 2022 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+import pathlib
+
+import pytest
+import torch
+from packaging import version
+from torch.utils.data import DataLoader
+
+from composer import State, Trainer
+from composer.callbacks import MemorySnapshot, OOMObserver
+from composer.loggers import LoggerDestination
+from composer.trainer import Trainer
+from tests.common import RandomClassificationDataset, SimpleModel
+
+
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'),
+ reason='OOM Observer requires PyTorch 2.1 or higher')
+def test_oom_observer_warnings_on_cpu_models():
+ ob = OOMObserver()
+ with pytest.warns(UserWarning):
+ Trainer(
+ model=SimpleModel(),
+ callbacks=ob,
+ train_dataloader=DataLoader(RandomClassificationDataset()),
+ max_duration='1ba',
+ device='cpu',
+ )
+ assert ob._enabled is False
+
+
+class FileUploaderTracker(LoggerDestination):
+
+ def __init__(self) -> None:
+ self.uploaded_files = []
+
+ def upload_file(self, state: State, remote_file_name: str, file_path: pathlib.Path, *, overwrite: bool):
+ del state, overwrite # unused
+ self.uploaded_files.append((remote_file_name, file_path))
+
+
+@pytest.mark.gpu
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'),
+ reason='OOM Observer requires PyTorch 2.1 or higher')
+def test_oom_observer():
+ # Construct the callbacks
+ oom_observer = OOMObserver()
+ simple_model = SimpleModel()
+ file_tracker_destination = FileUploaderTracker()
+
+ with pytest.raises(torch.cuda.OutOfMemoryError):
+ trainer = Trainer(
+ model=simple_model,
+ loggers=file_tracker_destination,
+ callbacks=oom_observer,
+ train_dataloader=DataLoader(RandomClassificationDataset()),
+ max_duration='2ba',
+ )
+
+ # trigger OOM
+ torch.empty(1024 * 1024 * 1024 * 1024, device='cuda')
+
+ trainer.fit()
+
+ assert len(file_tracker_destination.uploaded_files) == 5
+
+
+@pytest.mark.gpu
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'),
+ reason='OOM Observer requires PyTorch 2.1 or higher')
+def test_oom_observer_with_memory_snapshot():
+ # Construct the callbacks
+ oom_observer = OOMObserver()
+ memory_snapshot = MemorySnapshot(skip_batches=0, interval='1ba')
+ simple_model = SimpleModel()
+ file_tracker_destination = FileUploaderTracker()
+
+ trainer = Trainer(
+ model=simple_model,
+ loggers=file_tracker_destination,
+ callbacks=[oom_observer, memory_snapshot],
+ train_dataloader=DataLoader(RandomClassificationDataset()),
+ max_duration='2ba',
+ )
+
+ trainer.fit()
+ assert len(file_tracker_destination.uploaded_files) == 2
diff --git a/tests/callbacks/test_optimizer_monitor.py b/tests/callbacks/test_optimizer_monitor.py
index 226a38c119..02ee0586fb 100644
--- a/tests/callbacks/test_optimizer_monitor.py
+++ b/tests/callbacks/test_optimizer_monitor.py
@@ -11,7 +11,7 @@
from composer.models import HuggingFaceModel
from composer.optim import DecoupledAdamW
from composer.trainer import Trainer
-from composer.utils import dist, using_torch_2
+from composer.utils import dist
from tests.common import device, world_size
from tests.common.datasets import RandomClassificationDataset, RandomTextLMDataset
from tests.common.models import SimpleModel
@@ -57,16 +57,13 @@ def test_optimizer_monitor(log_optimizer_metrics: bool, batch_log_interval: int)
reason='requires PyTorch 1.13 or higher')
@pytest.mark.parametrize('use_orig_params', [True, False])
def test_fsdp_optimizer_monitor(device, world_size, use_orig_params):
- if use_orig_params and not using_torch_2():
- pytest.skip('use_orig_params was introduced in pytorch 2.0')
-
# Construct the callback
grad_monitor = OptimizerMonitor(log_optimizer_metrics=True)
in_memory_logger = InMemoryLogger() # track the logged metrics in the in_memory_logger
model = SimpleModel(num_classes=100, num_features=100, num_hidden=100)
for module in model.modules():
if len(list(module.parameters())) > 0:
- module._fsdp_wrap = True
+ module._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues]
dataset = RandomClassificationDataset(num_classes=100, shape=(100, 1, 1))
# Construct the trainer and train
trainer = Trainer(model=model,
@@ -91,12 +88,11 @@ def test_fsdp_optimizer_monitor(device, world_size, use_orig_params):
# Count the logged steps
grad_norm_calls = len(in_memory_logger.data['l2_norm/grad/global'])
layer_norm_calls = [len(calls) for (k, calls) in in_memory_logger.data.items() if 'l2_norm/grad' in k]
- suffix = ('._flat_param' if using_torch_2() else '.flat_param') if not use_orig_params else '.weight'
- infix = '' if using_torch_2() else '._fpw_module'
+ suffix = '._flat_param' if not use_orig_params else '.weight'
test_keys = [
- f'l2_norm/grad/module._fsdp_wrapped_module{infix}.4._fsdp_wrapped_module',
- f'l2_norm/moment/module._fsdp_wrapped_module{infix}.4._fsdp_wrapped_module',
- f'l2_norm/update/module._fsdp_wrapped_module{infix}.4._fsdp_wrapped_module',
+ f'l2_norm/grad/module._fsdp_wrapped_module.4._fsdp_wrapped_module',
+ f'l2_norm/moment/module._fsdp_wrapped_module.4._fsdp_wrapped_module',
+ f'l2_norm/update/module._fsdp_wrapped_module.4._fsdp_wrapped_module',
]
test_keys = [key + suffix for key in test_keys]
for key in test_keys:
@@ -110,12 +106,8 @@ def test_fsdp_optimizer_monitor(device, world_size, use_orig_params):
@device('gpu')
@world_size(1, 2)
-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
- reason='requires PyTorch 1.13 or higher')
@pytest.mark.parametrize('use_orig_params', [True, False])
def test_fsdp_optimizer_monitor_transformer(device, world_size, tiny_gpt2_model, tiny_gpt2_tokenizer, use_orig_params):
- if use_orig_params and not using_torch_2():
- pytest.skip('use_orig_params was introduced in pytorch 2.0')
transformers = pytest.importorskip('transformers')
# Construct the callback
grad_monitor = OptimizerMonitor(log_optimizer_metrics=True)
@@ -164,11 +156,9 @@ def test_fsdp_optimizer_monitor_transformer(device, world_size, tiny_gpt2_model,
layer_norm_calls = [len(calls) for (k, calls) in in_memory_logger.data.items() if 'l2_norm/grad' in k]
# an incomplete list of expected keys
if not use_orig_params:
- suffix = '._flat_param' if using_torch_2() else '.flat_param'
- infix = '' if using_torch_2() else '._fpw_module'
test_keys = [
- f'l2_norm/grad/model._fsdp_wrapped_module{infix}.transformer.h.1._fsdp_wrapped_module{suffix}',
- f'l2_norm/update/model._fsdp_wrapped_module{infix}.transformer.h.1._fsdp_wrapped_module{suffix}',
+ f'l2_norm/grad/model._fsdp_wrapped_module.transformer.h.1._fsdp_wrapped_module._flat_param',
+ f'l2_norm/update/model._fsdp_wrapped_module.transformer.h.1._fsdp_wrapped_module._flat_param',
]
else:
test_keys = [
diff --git a/tests/callbacks/test_speed_monitor.py b/tests/callbacks/test_speed_monitor.py
index f880a7c370..36c30a87f6 100644
--- a/tests/callbacks/test_speed_monitor.py
+++ b/tests/callbacks/test_speed_monitor.py
@@ -34,7 +34,7 @@ def test_speed_monitor(flops_per_batch: bool):
model = SimpleModel()
if flops_per_batch:
- model.flops_per_batch = lambda batch: len(batch) * 100.0
+ model.flops_per_batch = lambda batch: len(batch) * 100.0 # pyright: ignore[reportGeneralTypeIssues]
# Construct the trainer and train
trainer = Trainer(
diff --git a/tests/common/__init__.py b/tests/common/__init__.py
index be2a508860..bcc9903e61 100644
--- a/tests/common/__init__.py
+++ b/tests/common/__init__.py
@@ -12,7 +12,7 @@
from tests.common.markers import device, world_size
from tests.common.models import (ConvModel, EmbeddedWeightTiedModel, EmptyModel, SimpleConvModel, SimpleModel,
SimpleModelWithDropout, SimpleTransformerClassifier, SimpleTransformerMaskedLM,
- SimpleWeightTiedModel, ZeroModel)
+ SimpleWeightTiedModel, ZeroModel, composer_resnet)
from tests.common.state import assert_state_equivalent
@@ -46,4 +46,5 @@ def get_module_subclasses(module: types.ModuleType, cls: Type) -> List[Type]:
'ParityDataset',
'SimpleDataset',
'InfiniteClassificationDataset',
+ 'composer_resnet',
]
diff --git a/tests/common/models.py b/tests/common/models.py
index cac3769b38..d8bf2994d4 100644
--- a/tests/common/models.py
+++ b/tests/common/models.py
@@ -4,15 +4,21 @@
"""Contains commonly used models that are shared across the test suite."""
import copy
from functools import partial
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import pytest
import torch
from torchmetrics import Metric, MetricCollection
+from torchmetrics.classification import MulticlassAccuracy
+from torchvision.models import resnet
+from composer.loss import loss_registry
from composer.metrics import CrossEntropy, MIoU
from composer.metrics.nlp import LanguageCrossEntropy, MaskedAccuracy
-from composer.models import ComposerClassifier, HuggingFaceModel
+from composer.models import ComposerClassifier, HuggingFaceModel, Initializer
+
+if TYPE_CHECKING:
+ from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
class EmptyModel(ComposerClassifier):
@@ -74,7 +80,7 @@ def __init__(
fc2,
torch.nn.Softmax(dim=-1),
)
- net.param_init_fn = self.param_init_fn
+ net.param_init_fn = self.param_init_fn # pyright: ignore[reportGeneralTypeIssues]
super().__init__(module=net, num_classes=num_classes)
# Important: It is crucial that the FC layers are bound to `self`
@@ -90,7 +96,7 @@ def param_init_fn(self, module):
if isinstance(module, torch.nn.Linear):
init_fn(module.weight)
- if module.bias is not None:
+ if module.bias is not None: # pyright: ignore[reportUnnecessaryComparison]
torch.nn.init.zeros_(module.bias)
@@ -131,7 +137,7 @@ def __init__(self, num_features: int = 1, device: str = 'cpu') -> None:
self.mlp = mlp
self.net = net
- self.net.param_init_fn = self.param_init_fn
+ self.net.param_init_fn = self.param_init_fn # pyright: ignore[reportGeneralTypeIssues]
self.mlp.fc1.weight = self.mlp.fc2.weight
@@ -140,7 +146,7 @@ def param_init_fn(self, module):
if isinstance(module, torch.nn.Linear):
init_fn(module.weight)
- if module.bias is not None:
+ if module.bias is not None: # pyright: ignore[reportUnnecessaryComparison]
torch.nn.init.zeros_(module.bias)
@@ -166,7 +172,7 @@ def __init__(self, num_features: int = 1, device: str = 'cpu') -> None:
super().__init__(module=net, num_classes=num_features)
- self.module.param_init_fn = self.param_init_fn
+ self.module.param_init_fn = self.param_init_fn # pyright: ignore[reportGeneralTypeIssues]
self.net1 = net1
self.net2 = net2
@@ -178,7 +184,7 @@ def param_init_fn(self, module):
if isinstance(module, torch.nn.Linear):
init_fn(module.weight)
- if module.bias is not None:
+ if module.bias is not None: # pyright: ignore[reportUnnecessaryComparison]
torch.nn.init.zeros_(module.bias)
@@ -437,107 +443,224 @@ def forward(self, batch: Tuple[torch.Tensor, Any]) -> torch.Tensor:
return outputs
+def composer_resnet(
+ model_name: str,
+ num_classes: int = 1000,
+ weights: Optional[str] = None,
+ groups: int = 1,
+ width_per_group: int = 64,
+ initializers: Optional[List[Initializer]] = None,
+ loss_name: str = 'soft_cross_entropy',
+) -> ComposerClassifier:
+ """Helper function to create a :class:`.ComposerClassifier` with a torchvision ResNet model.
+ From `Deep Residual Learning for Image Recognition `_ (He et al, 2015).
+ Args:
+ model_name (str): Name of the ResNet model instance. Either [``"resnet18"``, ``"resnet34"``, ``"resnet50"``, ``"resnet101"``,
+ ``"resnet152"``].
+ num_classes (int, optional): The number of classes. Needed for classification tasks. Default: ``1000``.
+ weights (str, optional): If provided, pretrained weights can be specified, such as with ``IMAGENET1K_V2``. Default: ``None``.
+ groups (int, optional): Number of filter groups for the 3x3 convolution layer in bottleneck blocks. Default: ``1``.
+ width_per_group (int, optional): Initial width for each convolution group. Width doubles after each stage.
+ Default: ``64``.
+ initializers (List[Initializer], optional): Initializers for the model. ``None`` for no initialization.
+ Default: ``None``.
+ loss_name (str, optional): Loss function to use. E.g. 'soft_cross_entropy' or
+ 'binary_cross_entropy_with_logits'. Loss function must be in
+ :mod:`~composer.loss.loss`. Default: ``'soft_cross_entropy'``".
+ Returns:
+ ComposerModel: instance of :class:`.ComposerClassifier` with a torchvision ResNet model.
+ """
+ valid_model_names = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
+ if model_name not in valid_model_names:
+ raise ValueError(f'model_name must be one of {valid_model_names} instead of {model_name}.')
+
+ if loss_name not in loss_registry.keys():
+ raise ValueError(f'Unrecognized loss function: {loss_name}. Please ensure the '
+ 'specified loss function is present in composer.loss.loss.py')
+
+ if initializers is None:
+ initializers = []
+
+ # Instantiate model
+ model_fn = getattr(resnet, model_name)
+ model = model_fn(weights=weights, num_classes=num_classes, groups=groups, width_per_group=width_per_group)
+
+ # Grab loss function from loss registry
+ loss_fn = loss_registry[loss_name]
+
+ # Create metrics for train and validation
+ train_metrics = MulticlassAccuracy(num_classes=num_classes, average='micro')
+ val_metrics = MetricCollection([CrossEntropy(), MulticlassAccuracy(num_classes=num_classes, average='micro')])
+
+ # Apply Initializers to model
+ for initializer in initializers:
+ initializer = Initializer(initializer)
+ model.apply(initializer.get_initializer())
+
+ composer_model = ComposerClassifier(model, train_metrics=train_metrics, val_metrics=val_metrics, loss_fn=loss_fn)
+ return composer_model
+
+
# Note: These methods are an alternative to the tiny_bert fixtures in fixtures.py.
# Fixtures cannot be used natively as parametrized inputs, which we require when
# we wish to run a test across multiple models, one of which is a HuggingFace model.
# As a workaround, we inject objects into the PyTest namespace. Tests should not directly
# use pytest.{var}, but instead should import and use these helper copy methods so the
# objects in the PyTest namespace do not change.
-def configure_tiny_bert_model():
+def configure_tiny_bert_model() -> 'PreTrainedModel':
try:
+ from transformers import PreTrainedModel
+ assert isinstance(pytest.tiny_bert_model, PreTrainedModel)
return copy.deepcopy(pytest.tiny_bert_model)
except AttributeError:
pytest.skip('Composer installed without NLP support')
-def configure_tiny_bert_tokenizer():
+def configure_tiny_bert_tokenizer() -> Union['PreTrainedTokenizer', 'PreTrainedTokenizerFast']:
try:
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+ assert isinstance(pytest.tiny_bert_tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast))
return copy.deepcopy(pytest.tiny_bert_tokenizer)
except AttributeError:
pytest.skip('Composer installed without NLP support')
-def configure_tiny_bert_config():
+def configure_tiny_bert_config() -> 'PretrainedConfig':
try:
+ from transformers import PretrainedConfig
+ assert isinstance(pytest.tiny_bert_config, PretrainedConfig)
return copy.deepcopy(pytest.tiny_bert_config)
except AttributeError:
pytest.skip('Composer installed without NLP support')
-def configure_tiny_bert_hf_model(use_logits=True):
+def configure_tiny_bert_hf_model(use_logits: bool = True) -> HuggingFaceModel:
return HuggingFaceModel(configure_tiny_bert_model(), configure_tiny_bert_tokenizer(), use_logits)
-def configure_tiny_deberta_model():
+def configure_tiny_deberta_model() -> 'PreTrainedModel':
try:
+ from transformers import PreTrainedModel
+ assert isinstance(pytest.tiny_deberta_model, PreTrainedModel)
return copy.deepcopy(pytest.tiny_deberta_model)
except AttributeError:
pytest.skip('Composer installed without NLP support')
-def configure_tiny_deberta_tokenizer():
+def configure_tiny_deberta_tokenizer() -> Union['PreTrainedTokenizer', 'PreTrainedTokenizerFast']:
try:
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+ assert isinstance(pytest.tiny_deberta_tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast))
return copy.deepcopy(pytest.tiny_deberta_tokenizer)
except AttributeError:
pytest.skip('Composer installed without NLP support')
-def configure_tiny_deberta_config():
+def configure_tiny_deberta_config() -> 'PretrainedConfig':
try:
+ from transformers import PretrainedConfig
+ assert isinstance(pytest.tiny_deberta_config, PretrainedConfig)
return copy.deepcopy(pytest.tiny_deberta_config)
except AttributeError:
pytest.skip('Composer installed without NLP support')
-def configure_tiny_deberta_hf_model(use_logits=True):
- return HuggingFaceModel(configure_tiny_deberta_model(), configure_tiny_deberta_tokenizer(), use_logits)
+def configure_tiny_deberta_hf_model(use_logits: bool = True) -> HuggingFaceModel:
+ return HuggingFaceModel(
+ configure_tiny_deberta_model(),
+ configure_tiny_deberta_tokenizer(),
+ use_logits,
+ )
-def configure_tiny_gpt2_model():
+def configure_tiny_gpt2_model() -> 'PreTrainedModel':
try:
+ from transformers import PreTrainedModel
+ assert isinstance(pytest.tiny_gpt2_model, PreTrainedModel)
return copy.deepcopy(pytest.tiny_gpt2_model)
except AttributeError:
pytest.skip('Composer installed without NLP support')
-def configure_tiny_gpt2_tokenizer():
+def configure_tiny_gpt2_tokenizer() -> Union['PreTrainedTokenizer', 'PreTrainedTokenizerFast']:
try:
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+ assert isinstance(pytest.tiny_gpt2_tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast))
return copy.deepcopy(pytest.tiny_gpt2_tokenizer)
except AttributeError:
pytest.skip('Composer installed without NLP support')
-def configure_tiny_gpt2_config():
+def configure_tiny_gpt2_config() -> 'PretrainedConfig':
try:
+ from transformers import PretrainedConfig
+ assert isinstance(pytest.tiny_gpt2_config, PretrainedConfig)
return copy.deepcopy(pytest.tiny_gpt2_config)
except AttributeError:
pytest.skip('Composer installed without NLP support')
-def configure_tiny_gpt2_hf_model(use_logits=True):
+def configure_tiny_gpt2_hf_model(use_logits: bool = True) -> HuggingFaceModel:
return HuggingFaceModel(configure_tiny_gpt2_model(), configure_tiny_gpt2_tokenizer(), use_logits)
-def configure_tiny_t5_model():
+def configure_tiny_t5_model() -> 'PreTrainedModel':
try:
+ from transformers import PreTrainedModel
+ assert isinstance(pytest.tiny_t5_model, PreTrainedModel)
return copy.deepcopy(pytest.tiny_t5_model)
except AttributeError:
pytest.skip('Composer installed without NLP support')
-def configure_tiny_t5_tokenizer():
+def configure_tiny_t5_tokenizer() -> Union['PreTrainedTokenizer', 'PreTrainedTokenizerFast']:
try:
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+ assert isinstance(pytest.tiny_t5_tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast))
return copy.deepcopy(pytest.tiny_t5_tokenizer)
except AttributeError:
pytest.skip('Composer installed without NLP support')
-def configure_tiny_t5_config():
+def configure_tiny_t5_config() -> 'PretrainedConfig':
try:
+ from transformers import PretrainedConfig
+ assert isinstance(pytest.tiny_t5_config, PretrainedConfig)
return copy.deepcopy(pytest.tiny_t5_config)
except AttributeError:
pytest.skip('Composer installed without NLP support')
-def configure_tiny_t5_hf_model(use_logits=True):
+def configure_tiny_t5_hf_model(use_logits: bool = True) -> HuggingFaceModel:
return HuggingFaceModel(configure_tiny_t5_model(), configure_tiny_t5_tokenizer(), use_logits)
+
+
+def configure_tiny_mistral_model() -> 'PreTrainedModel':
+ try:
+ from transformers import PreTrainedModel
+ assert isinstance(pytest.tiny_mistral_model, PreTrainedModel)
+ return copy.deepcopy(pytest.tiny_mistral_model)
+ except AttributeError:
+ pytest.skip('Composer installed without NLP support')
+
+
+def configure_tiny_mistral_tokenizer() -> Union['PreTrainedTokenizer', 'PreTrainedTokenizerFast']:
+ try:
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+ assert isinstance(pytest.tiny_mistral_tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast))
+ return copy.deepcopy(pytest.tiny_mistral_tokenizer)
+ except AttributeError:
+ pytest.skip('Composer installed without NLP support')
+
+
+def configure_tiny_mistral_config() -> 'PretrainedConfig':
+ try:
+ from transformers import PretrainedConfig
+ assert isinstance(pytest.tiny_mistral_config, PretrainedConfig)
+ return copy.deepcopy(pytest.tiny_mistral_config)
+ except AttributeError:
+ pytest.skip('Composer installed without NLP support')
+
+
+def configure_tiny_mistral_hf_model(use_logits: bool = True) -> HuggingFaceModel:
+ return HuggingFaceModel(configure_tiny_mistral_model(), configure_tiny_mistral_tokenizer(), use_logits)
diff --git a/tests/conftest.py b/tests/conftest.py
index bcd063d9c7..e327730d42 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -111,21 +111,27 @@ def pytest_configure():
if TRANSFORMERS_INSTALLED:
from tests.fixtures.fixtures import (tiny_bert_config_helper, tiny_bert_model_helper,
tiny_bert_tokenizer_helper, tiny_gpt2_config_helper,
- tiny_gpt2_model_helper, tiny_gpt2_tokenizer_helper, tiny_opt_config_helper,
- tiny_opt_model_helper, tiny_opt_tokenizer_helper, tiny_t5_config_helper,
- tiny_t5_model_helper, tiny_t5_tokenizer_helper)
+ tiny_gpt2_model_helper, tiny_gpt2_tokenizer_helper,
+ tiny_llama_tokenizer_helper, tiny_mistral_config_helper,
+ tiny_mistral_model_helper, tiny_mistral_tokenizer_helper,
+ tiny_opt_config_helper, tiny_opt_model_helper, tiny_opt_tokenizer_helper,
+ tiny_t5_config_helper, tiny_t5_model_helper, tiny_t5_tokenizer_helper)
pytest.tiny_bert_config = tiny_bert_config_helper() # type: ignore
pytest.tiny_bert_model = tiny_bert_model_helper(pytest.tiny_bert_config) # type: ignore
pytest.tiny_bert_tokenizer = tiny_bert_tokenizer_helper() # type: ignore
pytest.tiny_gpt2_config = tiny_gpt2_config_helper() # type: ignore
pytest.tiny_gpt2_model = tiny_gpt2_model_helper(pytest.tiny_gpt2_config) # type: ignore
pytest.tiny_gpt2_tokenizer = tiny_gpt2_tokenizer_helper() # type: ignore
+ pytest.tiny_llama_tokenizer = tiny_llama_tokenizer_helper() # type: ignore
pytest.tiny_opt_config = tiny_opt_config_helper() # type: ignore
pytest.tiny_opt_model = tiny_opt_model_helper(pytest.tiny_opt_config) # type: ignore
pytest.tiny_opt_tokenizer = tiny_opt_tokenizer_helper() # type: ignore
pytest.tiny_t5_config = tiny_t5_config_helper() # type: ignore
pytest.tiny_t5_model = tiny_t5_model_helper(pytest.tiny_t5_config) # type: ignore
pytest.tiny_t5_tokenizer = tiny_t5_tokenizer_helper() # type: ignore
+ pytest.tiny_mistral_config = tiny_mistral_config_helper() # type: ignore
+ pytest.tiny_mistral_model = tiny_mistral_model_helper(pytest.tiny_mistral_config) # type: ignore
+ pytest.tiny_mistral_tokenizer = tiny_mistral_tokenizer_helper() # type: ignore
def pytest_sessionfinish(session: pytest.Session, exitstatus: int):
diff --git a/tests/datasets/test_add_dataset_transform.py b/tests/datasets/test_add_dataset_transform.py
deleted file mode 100644
index d7a545a33b..0000000000
--- a/tests/datasets/test_add_dataset_transform.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-import pytest
-from torchvision import transforms
-
-from composer.datasets.synthetic import SyntheticPILDataset
-from composer.datasets.utils import add_vision_dataset_transform
-
-image_size = 32
-
-
-def generate_synthetic_dataset(data_transforms):
- return SyntheticPILDataset(total_dataset_size=1000,
- data_shape=[image_size, image_size],
- num_classes=2,
- transform=data_transforms)
-
-
-def generate_default_transforms():
- return transforms.Compose([transforms.RandomCrop(32), transforms.ToTensor(), transforms.RandomRotation(5)])
-
-
-def generate_composition_no_tensor():
- return transforms.Compose(
- [transforms.RandomCrop(32),
- transforms.RandomHorizontalFlip(),
- transforms.RandomRotation(5)])
-
-
-@pytest.mark.parametrize('is_tensor_transform,index', [(False, 1), (True, 2)])
-def test_pre_post_to_tensor_compose(is_tensor_transform, index):
- dataset = generate_synthetic_dataset(generate_default_transforms())
- add_vision_dataset_transform(dataset, transforms.RandomAutocontrast(), is_tensor_transform=is_tensor_transform)
- assert dataset.transform is not None
- assert type(dataset.transform.transforms[index]) == transforms.RandomAutocontrast
-
-
-@pytest.mark.parametrize('is_tensor_transform,index', [(False, 0), (True, 1)])
-def test_pre_post_to_tensor(is_tensor_transform, index):
- dataset = generate_synthetic_dataset(transforms.ToTensor())
- add_vision_dataset_transform(dataset, transforms.RandomAutocontrast(), is_tensor_transform=is_tensor_transform)
- assert dataset.transform is not None
- assert type(dataset.transform.transforms[index]) == transforms.RandomAutocontrast
-
-
-@pytest.mark.parametrize('data_transforms', [(generate_composition_no_tensor()), (transforms.RandomHorizontalFlip())])
-def test_default_to_append(data_transforms):
- dataset = generate_synthetic_dataset(data_transforms)
- add_vision_dataset_transform(dataset, transforms.RandomAutocontrast())
- assert dataset.transform is not None
- assert type(dataset.transform.transforms[-1]) == transforms.RandomAutocontrast
-
-
-def test_add_to_none_transform():
- dataset = generate_synthetic_dataset(None)
- add_vision_dataset_transform(dataset, transforms.RandomAutocontrast())
- assert type(dataset.transform) == transforms.RandomAutocontrast
diff --git a/tests/datasets/test_cifar.py b/tests/datasets/test_cifar.py
deleted file mode 100644
index 6eac6e2ebf..0000000000
--- a/tests/datasets/test_cifar.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-import pytest
-
-from composer.datasets import build_cifar10_dataloader, build_synthetic_cifar10_dataloader
-
-
-@pytest.mark.skip # Download is flaky and test is not critical
-@pytest.mark.parametrize('is_train', [False, True])
-@pytest.mark.parametrize('synthetic', [pytest.param(False, marks=pytest.mark.daily), True])
-def test_cifar10_shape_length(is_train, synthetic):
- batch_size = 1
-
- if synthetic:
- dataspec = build_synthetic_cifar10_dataloader(global_batch_size=batch_size, is_train=is_train)
- else:
- dataspec = build_cifar10_dataloader(datadir='/tmp', global_batch_size=batch_size, is_train=is_train)
-
- samples = list(dataspec.dataloader)
- if is_train:
- assert len(samples) == 50000 // batch_size
- else:
- assert len(samples) == 10000 // batch_size
-
- assert samples[0][0].shape == (1, 3, 32, 32)
diff --git a/tests/datasets/test_dataset_utils.py b/tests/datasets/test_dataset_utils.py
deleted file mode 100644
index 720edce59b..0000000000
--- a/tests/datasets/test_dataset_utils.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-from typing import List, Tuple
-
-import numpy as np
-import pytest
-import torch
-from PIL import Image
-
-from composer.datasets.utils import pil_image_collate
-
-
-@pytest.fixture
-def num_samples():
- return 4
-
-
-@pytest.fixture
-def image_size():
- return (16, 16)
-
-
-@pytest.fixture
-def pil_image_list(num_samples: int, image_size: Tuple[int, int]):
- return [Image.new(mode='RGB', size=image_size, color=(i, i, i)) for i in range(num_samples)]
-
-
-@pytest.fixture
-def pil_target_list(num_samples: int, image_size: Tuple[int, int]):
- return [Image.new(mode='L', size=image_size, color=i) for i in range(num_samples)]
-
-
-@pytest.fixture
-def correct_image_tensor(num_samples: int, image_size: Tuple[int, int]):
- return torch.arange(num_samples).expand(3, *image_size, -1).permute(3, 0, 1, 2)
-
-
-@pytest.fixture
-def scalar_target_list(num_samples: int):
- return np.arange(num_samples)
-
-
-def test_scalar_target_collate(pil_image_list: List[Image.Image], scalar_target_list: np.ndarray,
- correct_image_tensor: torch.Tensor):
- batch = [(img, target) for img, target in zip(pil_image_list, scalar_target_list)]
- image_tensor, target_tensor = pil_image_collate(batch=batch)
-
- correct_target_tensor = torch.arange(correct_image_tensor.shape[0])
-
- assert torch.all(image_tensor == correct_image_tensor) and torch.all(target_tensor == correct_target_tensor)
-
-
-def test_image_target_collate(pil_image_list: List[Image.Image], pil_target_list: List[Image.Image],
- correct_image_tensor):
- batch = [(img, target) for img, target in zip(pil_image_list, pil_target_list)]
- image_tensor, target_tensor = pil_image_collate(
- batch=batch) # type: ignore "Image" is incompatible with "ndarray[Unknown, Unknown]"
-
- assert torch.all(image_tensor == correct_image_tensor) and torch.all(target_tensor == correct_image_tensor[:, 0])
diff --git a/tests/datasets/test_ffcv_utils.py b/tests/datasets/test_ffcv_utils.py
deleted file mode 100644
index 3614d73387..0000000000
--- a/tests/datasets/test_ffcv_utils.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-import os
-import pathlib
-
-import pytest
-
-from composer.datasets.ffcv_utils import write_ffcv_dataset
-from composer.datasets.synthetic import SyntheticDataLabelType, SyntheticPILDataset
-
-
-@pytest.mark.vision
-def test_write_ffcv_dataset(tmp_path: pathlib.Path):
- dataset = SyntheticPILDataset(total_dataset_size=1,
- num_classes=1,
- data_shape=[1, 1, 3],
- label_type=SyntheticDataLabelType.CLASSIFICATION_INT,
- num_unique_samples_to_create=1)
- output_file = str(tmp_path / 'ffcv')
- write_ffcv_dataset(dataset, write_path=output_file, num_workers=1)
- assert os.path.exists(output_file)
diff --git a/tests/datasets/test_in_context_learning_datasets.py b/tests/datasets/test_in_context_learning_datasets.py
index 2e9a461fcf..063f7215bc 100644
--- a/tests/datasets/test_in_context_learning_datasets.py
+++ b/tests/datasets/test_in_context_learning_datasets.py
@@ -9,15 +9,28 @@
import pytest
import torch
-import transformers
from torch.utils.data import DataLoader
-from transformers import AutoTokenizer
from composer import Evaluator
from composer.core import DataSpec
-from composer.datasets.in_context_learning_evaluation import (InContextLearningCodeEvalDataset,
- _get_fewshot_sample_idxs, _make_padded_input,
- get_icl_task_dataloader)
+
+# isort: off
+from composer.datasets.in_context_learning_evaluation import (
+ InContextLearningCodeEvalDataset,
+ InContextLearningDataset,
+ InContextLearningMultipleChoiceTaskDataset,
+ InContextLearningQATaskDataset,
+ InContextLearningSchemaTaskDataset,
+ _get_continuation_span,
+ _get_fewshot_sample_idxs,
+ _make_padded_input,
+ _tokenizer_needs_prefix_space,
+ _trim_context,
+ get_icl_task_dataloader,
+ strip_data,
+)
+# isort: on
+from composer.datasets.utils import MultiTokenEOSCriteria
from composer.loggers import InMemoryLogger
from composer.metrics import (InContextLearningCodeEvalAccuracy, InContextLearningLMAccuracy,
InContextLearningMultipleChoiceAccuracy, InContextLearningQAAccuracy)
@@ -27,19 +40,122 @@
from tests.common import device, world_size
+def test_strip_data():
+ data_to_strip = {'strip_data': ' boo! \n', 'has_space': ' wa hoo!', 'end_space': 'yoohoo! '}
+ stripped_data = strip_data(data_to_strip)
+ for k, v in stripped_data.items():
+ assert k in data_to_strip
+ assert not v[0].isspace()
+ assert not v[-1].isspace()
+
+
+@pytest.mark.skip(reason="Currently don't have a tokenizer that satisfies this test")
+def test_tokenizer_needs_prefix_space_when_space_not_needed(tiny_gpt2_tokenizer):
+ assert not _tokenizer_needs_prefix_space(tiny_gpt2_tokenizer)
+
+
+def test_tokenizer_needs_prefix_space_when_space_needed():
+ transformers = pytest.importorskip('transformers')
+ tokenizer = transformers.AutoTokenizer.from_pretrained('facebook/opt-125m',
+ use_fast=False) # type: ignore reportUnboundVariable
+ assert _tokenizer_needs_prefix_space(tokenizer)
+
+
+def test_trim_context():
+ context = [0] * 99 + [1] * 2037
+ continuation = [2] * 10
+ max_seq_len = 2048
+ trimmed_context = _trim_context(context, continuation, max_seq_len=max_seq_len)
+ assert len(trimmed_context) == 2038
+ assert trimmed_context[0] == 0
+ assert trimmed_context[1] == 1
+
+
+def test_trim_context_no_continuation():
+ context = [0] * 2048
+ max_seq_len = 2048
+ trimmed_context = _trim_context(context, [], max_seq_len=max_seq_len)
+ assert len(trimmed_context) == 2048
+ context = [0] * 3000 + [1]
+ max_seq_len = 2048
+ trimmed_context = _trim_context(context, [], max_seq_len=max_seq_len)
+ assert len(trimmed_context) == 2048
+ assert trimmed_context[-1] == 1
+
+
+def test_get_continuation_span():
+ context = [0] * 200
+ continuation = [1] * 3
+ cont_span = _get_continuation_span(context, continuation)
+ assert torch.all(torch.eq(cont_span, torch.tensor([200, 201, 202])))
+ continuation = [1]
+ cont_span = _get_continuation_span(context, continuation)
+ assert torch.all(torch.eq(cont_span, torch.tensor([200])))
+
+
+@pytest.mark.parametrize('padding_side', ['left', 'right', 'middle'])
+def test_make_padding(tiny_gpt2_tokenizer, padding_side):
+ context = tiny_gpt2_tokenizer(' cat' * 2000)['input_ids']
+ padding_id = tiny_gpt2_tokenizer.eos_token_id
+
+ error_context = contextlib.nullcontext() if padding_side in {'left', 'right'} else pytest.raises(ValueError)
+
+ with error_context:
+ input_ids = _make_padded_input(context, [], 2048, padding_id, padding_side=padding_side)
+
+ if padding_side == 'left':
+ assert input_ids[0] == tiny_gpt2_tokenizer.eos_token_id
+ assert input_ids[48:].tolist() == context
+ elif padding_side == 'right':
+ assert input_ids[-1] == tiny_gpt2_tokenizer.eos_token_id
+ assert input_ids[:-48].tolist() == context
+
+
+def test_batch_padding_logic_no_padding(tiny_gpt2_tokenizer):
+ continuation = tiny_gpt2_tokenizer(' dog' * 2000)['input_ids']
+ context = tiny_gpt2_tokenizer(' cat' * 2000)['input_ids']
+ max_seq_len = 2048
+ trimmed_context = _trim_context(context, continuation, max_seq_len)
+ continuation_spans = _get_continuation_span(trimmed_context, continuation)
+ padded_input = _make_padded_input(trimmed_context,
+ continuation,
+ max_seq_len,
+ tiny_gpt2_tokenizer.pad_token_id,
+ padding_side='right')
+ assert continuation_spans[0] == 48 and continuation_spans[-1] == 2047
+ assert len(padded_input) == 2048
+ assert tiny_gpt2_tokenizer.pad_token_id not in padded_input
+
+
+def test_batch_padding_logic_with_padding(tiny_gpt2_tokenizer):
+ continuation = tiny_gpt2_tokenizer(' dog' * 200)['input_ids']
+ context = tiny_gpt2_tokenizer(' cat' * 200)['input_ids']
+ max_seq_len = 2048
+ trimmed_context = _trim_context(context, continuation, max_seq_len)
+ continuation_spans = _get_continuation_span(trimmed_context, continuation)
+ padded_input = _make_padded_input(trimmed_context,
+ continuation,
+ max_seq_len,
+ tiny_gpt2_tokenizer.pad_token_id,
+ padding_side='right')
+ assert continuation_spans[0] == 200 and continuation_spans[-1] == 399
+ assert len(padded_input) == 2048
+ assert padded_input[-1] == tiny_gpt2_tokenizer.pad_token_id
+
+
def test_fewshot_sample_idxs():
rng = random.Random(1234)
- fewshot_idxs = _get_fewshot_sample_idxs(dataset_size=5, num_fewshot=4, sample_idx=4, rng=rng)
+ fewshot_idxs = _get_fewshot_sample_idxs(dataset_size=5, num_fewshot=4, example_idx=4, rng=rng)
assert fewshot_idxs == {0, 1, 2, 3}
- fewshot_idxs = _get_fewshot_sample_idxs(dataset_size=5, num_fewshot=5, sample_idx=4, rng=rng)
+ fewshot_idxs = _get_fewshot_sample_idxs(dataset_size=5, num_fewshot=5, example_idx=4, rng=rng)
assert fewshot_idxs == {0, 1, 2, 3}
- fewshot_idxs = _get_fewshot_sample_idxs(dataset_size=5, num_fewshot=500, sample_idx=4, rng=rng)
+ fewshot_idxs = _get_fewshot_sample_idxs(dataset_size=5, num_fewshot=500, example_idx=4, rng=rng)
assert fewshot_idxs == {0, 1, 2, 3}
- fewshot_idxs = _get_fewshot_sample_idxs(dataset_size=10, num_fewshot=7, sample_idx=4, rng=rng)
+ fewshot_idxs = _get_fewshot_sample_idxs(dataset_size=10, num_fewshot=7, example_idx=4, rng=rng)
assert len(fewshot_idxs) == 7 and 4 not in fewshot_idxs
@@ -66,30 +182,667 @@ def test_fewshot_sample_idxs_randomness():
assert rng_1_sample_2 != rng_3_sample_2
-def test_batch_padding_logic(tiny_gpt2_tokenizer):
- continuation = tiny_gpt2_tokenizer(' dog' * 2000)['input_ids']
- context = tiny_gpt2_tokenizer(' cat' * 2000)['input_ids']
- _, continuation_spans = _make_padded_input(context, continuation, 2048, tiny_gpt2_tokenizer.eos_token_id)
- # the context (of len 2000) gets clipped to len 48 so that the whole continuation can fit
- assert continuation_spans[0] == 48 and continuation_spans[-1] == 2047
+@pytest.mark.filterwarnings(
+ r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning')
+def test_update_generation_kwargs(tiny_gpt2_tokenizer, tmp_path):
+ tokenizer = tiny_gpt2_tokenizer
+ seqlen = 2048
+ num_fewshot = 0
+ prompt_string = ''
+ hf_loading_vars = {
+ 'split': 'test',
+ 'name': 'invoker',
+ }
+ hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']}
+ gen_kwargs = {'test_arg1': 1, 'test_arg2': 2}
+ dl = InContextLearningDataset(dataset_uri='hf://mosaicml/test_dataset',
+ tokenizer=tokenizer,
+ max_seq_len=seqlen,
+ pad_tok_id=tokenizer.eos_token_id,
+ num_fewshot=num_fewshot,
+ fewshot_random_seed=1,
+ prompt_string=prompt_string,
+ example_delimiter='\n',
+ prelimiter='Orbs: ',
+ continuation_delimiter='\nSpell:',
+ destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'),
+ hf_loading_vars=hf_loading_vars,
+ hf_parsing_map=hf_parsing_map,
+ generation_kwargs=gen_kwargs)
+ assert dl.base_batch['generation_kwargs'] == {'test_arg1': 1, 'test_arg2': 2}
+
+
+def test_stop_sequences_criteria(tiny_gpt2_tokenizer):
+ pytest.importorskip('transformers')
+ eos_criteria = MultiTokenEOSCriteria('\n\n', tiny_gpt2_tokenizer, 2)
+ seq1 = tiny_gpt2_tokenizer('Dogs are furry')['input_ids']
+ seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
+ seq1 = [tiny_gpt2_tokenizer.pad_token_id] * (len(seq2) - len(seq1)) + seq1
+ input_ids = torch.LongTensor([seq1, seq2])
+ assert not eos_criteria(input_ids, None) # pyright: ignore[reportGeneralTypeIssues]
+
+ eos_criteria = MultiTokenEOSCriteria('\n\n', tiny_gpt2_tokenizer, 2)
+ seq1 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
+ seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
+ input_ids = torch.LongTensor([seq1, seq2])
+ assert eos_criteria(input_ids, None) # pyright: ignore[reportGeneralTypeIssues]
+
+
+def test_stop_sequences_criteria_sentencepiece(tiny_llama_tokenizer):
+ pytest.importorskip('datasets')
-@pytest.mark.parametrize('padding_side', ['left', 'right', 'middle'])
-def test_make_padding(tiny_gpt2_tokenizer, padding_side):
- context = tiny_gpt2_tokenizer(' cat' * 2000)['input_ids']
- padding_id = tiny_gpt2_tokenizer.eos_token_id
+ tokenizer = tiny_llama_tokenizer
+ eos_criteria = MultiTokenEOSCriteria('\n\n', tokenizer, 2)
+ seq1 = tokenizer('\n\nDogs')['input_ids'] # check to make sure starting with the stop sequence doesnt break it
+ seq2 = tokenizer('Dogs are furry\n\n')['input_ids']
+ seq1 = [tokenizer.eos_token_id] * (len(seq2) - len(seq1)) + seq1
+ input_ids = torch.LongTensor([seq1, seq2])
+ assert not eos_criteria(input_ids, None) # pyright: ignore[reportGeneralTypeIssues]
- error_context = contextlib.nullcontext() if padding_side in {'left', 'right'} else pytest.raises(ValueError)
+ eos_criteria = MultiTokenEOSCriteria('\n\n', tokenizer, 2)
+ seq1 = tokenizer('Dogs are furry\n\n')['input_ids']
+ seq2 = tokenizer('Dogs are furry\n\n')['input_ids']
+ input_ids = torch.LongTensor([seq1, seq2])
+ assert eos_criteria(input_ids, None) # pyright: ignore[reportGeneralTypeIssues]
- with error_context:
- input_ids, _ = _make_padded_input(context, [], 2048, padding_id, padding_side=padding_side)
- if padding_side == 'left':
- assert input_ids[0] == tiny_gpt2_tokenizer.eos_token_id
- assert input_ids[48:].tolist() == context
- elif padding_side == 'right':
- assert input_ids[-1] == tiny_gpt2_tokenizer.eos_token_id
- assert input_ids[:-48].tolist() == context
+@pytest.mark.filterwarnings(
+ r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning')
+def test_update_generation_kwargs_no_kwargs(tiny_gpt2_tokenizer, tmp_path):
+ tokenizer = tiny_gpt2_tokenizer
+ seqlen = 2048
+ num_fewshot = 0
+ prompt_string = ''
+ hf_loading_vars = {
+ 'split': 'test',
+ 'name': 'invoker',
+ }
+ hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']}
+
+ dl = InContextLearningDataset(dataset_uri='hf://mosaicml/test_dataset',
+ tokenizer=tokenizer,
+ max_seq_len=seqlen,
+ pad_tok_id=tokenizer.eos_token_id,
+ num_fewshot=num_fewshot,
+ fewshot_random_seed=1,
+ prompt_string=prompt_string,
+ example_delimiter='\n',
+ prelimiter='Orbs: ',
+ continuation_delimiter='\nSpell:',
+ destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'),
+ hf_loading_vars=hf_loading_vars,
+ hf_parsing_map=hf_parsing_map)
+ assert not 'generation_kwargs' in dl.base_batch
+
+
+def test_update_generation_kwargs_no_kwargs_qa_dataset(tmp_path):
+ pytest.importorskip('datasets')
+ local_data = os.path.join(os.path.dirname(__file__), 'local_data')
+ dataset_uri = f'{local_data}/triviaqa_small.jsonl'
+ transformers = pytest.importorskip('transformers')
+ tokenizer = transformers.AutoTokenizer.from_pretrained('facebook/opt-125m') # type: ignore reportUnboundVariable
+
+ tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
+ gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
+ dl = InContextLearningQATaskDataset(dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ max_seq_len=1024,
+ pad_tok_id=tokenizer.eos_token_id,
+ num_fewshot=0,
+ fewshot_random_seed=1234,
+ prompt_string='',
+ example_delimiter='\n',
+ continuation_delimiter=': ',
+ destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'),
+ generation_kwargs=None)
+ assert len(dl.base_batch['generation_kwargs']) == 3
+
+
+def test_update_generation_kwargs_with_kwargs_qa_dataset(tmp_path):
+ pytest.importorskip('datasets')
+ local_data = os.path.join(os.path.dirname(__file__), 'local_data')
+ dataset_uri = f'{local_data}/triviaqa_small.jsonl'
+ transformers = pytest.importorskip('transformers')
+ tokenizer = transformers.AutoTokenizer.from_pretrained('facebook/opt-125m') # type: ignore reportUnboundVariable
+
+ tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
+ gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
+ dl = InContextLearningQATaskDataset(dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ max_seq_len=1024,
+ pad_tok_id=tokenizer.eos_token_id,
+ num_fewshot=0,
+ fewshot_random_seed=1234,
+ prompt_string='',
+ example_delimiter='\n',
+ continuation_delimiter=': ',
+ destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'),
+ generation_kwargs={'temperature': 0.9})
+ assert 'generation_kwargs' in dl.base_batch
+ assert dl.base_batch['generation_kwargs']['temperature'] == 0.9
+ assert len(dl.base_batch['generation_kwargs']) == 4
+
+
+@pytest.mark.filterwarnings(
+ r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning')
+def test_construct_context(tiny_gpt2_tokenizer, tmp_path):
+ tokenizer = tiny_gpt2_tokenizer
+ seqlen = 2048
+ num_fewshot = 0
+ prompt_string = ''
+ hf_loading_vars = {
+ 'split': 'test',
+ 'name': 'invoker',
+ }
+ hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']}
+
+ dl = InContextLearningDataset(dataset_uri='hf://mosaicml/test_dataset',
+ tokenizer=tokenizer,
+ max_seq_len=seqlen,
+ pad_tok_id=tokenizer.eos_token_id,
+ num_fewshot=num_fewshot,
+ fewshot_random_seed=1,
+ prompt_string=prompt_string,
+ example_delimiter='\n',
+ prelimiter='Orbs: ',
+ continuation_delimiter='\nSpell: ',
+ destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'),
+ hf_loading_vars=hf_loading_vars,
+ hf_parsing_map=hf_parsing_map)
+ constructed_context = dl.construct_context({'context': 'quas quas exort', 'answer': 'ice wall'})
+ assert constructed_context == 'Orbs: quas quas exort\nSpell: '
+ constructed_context = dl.construct_context({'context': 'quas quas exort', 'answer': 'ice wall'}, add_answer=True)
+ assert constructed_context == 'Orbs: quas quas exort\nSpell: ice wall'
+ constructed_context = dl.construct_context({
+ 'context': 'quas quas exort',
+ 'answer': 'ice wall'
+ },
+ preceding_text='The harsh White Waste beckons!',
+ add_answer=True)
+ assert constructed_context == '\nOrbs: quas quas exort\nSpell: ice wall'
+
+
+@pytest.mark.filterwarnings(
+ r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning')
+def test_get_answer_from_example(tiny_gpt2_tokenizer, tmp_path):
+ tokenizer = tiny_gpt2_tokenizer
+ seqlen = 2048
+ num_fewshot = 0
+ prompt_string = ''
+ hf_loading_vars = {
+ 'split': 'test',
+ 'name': 'invoker',
+ }
+ hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']}
+
+ dl = InContextLearningDataset(dataset_uri='hf://mosaicml/test_dataset',
+ tokenizer=tokenizer,
+ max_seq_len=seqlen,
+ pad_tok_id=tokenizer.eos_token_id,
+ num_fewshot=num_fewshot,
+ fewshot_random_seed=1,
+ prompt_string=prompt_string,
+ example_delimiter='\n',
+ prelimiter='Orbs: ',
+ continuation_delimiter='\nSpell:',
+ destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'),
+ hf_loading_vars=hf_loading_vars,
+ hf_parsing_map=hf_parsing_map)
+ answer = dl.get_answer_from_example({'context': 'wex exort exort', 'answer': 'alacrity'})
+ assert answer == ' alacrity'
+
+
+@pytest.mark.filterwarnings(
+ r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning')
+def test_fix_eos_on_preamble(tmp_path):
+ transformers = pytest.importorskip('transformers')
+ tokenizer = transformers.AutoTokenizer.from_pretrained('facebook/opt-125m',
+ use_fast=False) # type: ignore reportUnboundVariable
+ seqlen = 2048
+ num_fewshot = 0
+ prompt_string = ''
+ hf_loading_vars = {
+ 'split': 'test',
+ 'name': 'invoker',
+ }
+ hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']}
+
+ dl = InContextLearningDataset(dataset_uri='hf://mosaicml/test_dataset',
+ tokenizer=tokenizer,
+ max_seq_len=seqlen,
+ pad_tok_id=tokenizer.eos_token_id,
+ num_fewshot=num_fewshot,
+ fewshot_random_seed=1,
+ prompt_string=prompt_string,
+ example_delimiter='\n',
+ prelimiter='Orbs: ',
+ continuation_delimiter='\nSpell:',
+ destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'),
+ hf_loading_vars=hf_loading_vars,
+ hf_parsing_map=hf_parsing_map)
+ preamble = 'blah blah blah.'
+ tokenized_preamble = tokenizer.encode(preamble)
+ tokenized_preamble += [tokenizer.eos_token_id]
+ fixed_preamble = dl._fix_eos_on_preamble(tokenized_preamble)
+ assert tokenized_preamble[:-1] == fixed_preamble
+ assert fixed_preamble[-1] != tokenizer.eos_token_id
+
+
+@pytest.mark.filterwarnings(
+ r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning')
+def test_tokenize_example_with_tokenize_labels(tiny_gpt2_tokenizer, tmp_path):
+ tokenizer = tiny_gpt2_tokenizer
+ seqlen = 2048
+ num_fewshot = 0
+ prompt_string = ''
+ hf_loading_vars = {
+ 'split': 'test',
+ 'name': 'invoker',
+ }
+ hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']}
+
+ dl = InContextLearningDataset(dataset_uri='hf://mosaicml/test_dataset',
+ tokenizer=tokenizer,
+ max_seq_len=seqlen,
+ pad_tok_id=tokenizer.eos_token_id,
+ num_fewshot=num_fewshot,
+ fewshot_random_seed=1,
+ prompt_string=prompt_string,
+ example_delimiter='\n',
+ prelimiter='Orbs: ',
+ continuation_delimiter='\nSpell: ',
+ destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'),
+ hf_loading_vars=hf_loading_vars,
+ hf_parsing_map=hf_parsing_map,
+ tokenize_labels=True)
+ tokenized_example = dl.tokenize_example('What spell does this invoke? ', 'exort exort wex\nSpell: ',
+ {'answer': ' Meatball'})
+ tokenized_input = [2061, 4822, 857, 428, 26342, 30, 220, 1069, 419, 409, 419, 356, 87, 198, 31221, 25, 19145, 1894]
+ assert tokenized_example['context'][:len(tokenized_input)].tolist() == tokenized_input
+ assert tokenized_example['context'][-1] == tokenizer.eos_token_id
+ assert type(tokenized_example['answer'][0]) == int
+ assert len(tokenized_example['context']) == seqlen
+ assert 'continuation_indices' in tokenized_example
+
+
+@pytest.mark.filterwarnings(
+ r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning')
+def test_tokenize_example_with_no_tokenize_labels(tiny_gpt2_tokenizer, tmp_path):
+ tokenizer = tiny_gpt2_tokenizer
+ seqlen = 2048
+ num_fewshot = 0
+ prompt_string = ''
+ hf_loading_vars = {
+ 'split': 'test',
+ 'name': 'invoker',
+ }
+ hf_parsing_map = {'context': ['quas', 'wex', 'exort'], 'answer': ['spell']}
+
+ dl = InContextLearningDataset(dataset_uri='hf://mosaicml/test_dataset',
+ tokenizer=tokenizer,
+ max_seq_len=seqlen,
+ pad_tok_id=tokenizer.eos_token_id,
+ num_fewshot=num_fewshot,
+ fewshot_random_seed=1,
+ prompt_string=prompt_string,
+ example_delimiter='\n',
+ prelimiter='Orbs: ',
+ continuation_delimiter='\nSpell: ',
+ destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'),
+ hf_loading_vars=hf_loading_vars,
+ hf_parsing_map=hf_parsing_map,
+ tokenize_labels=False)
+ tokenized_example = dl.tokenize_example('What spell does this invoke? ', 'exort exort wex\nSpell: ',
+ {'answer': ' Meatball'})
+ tokenized_input = [2061, 4822, 857, 428, 26342, 30, 220, 1069, 419, 409, 419, 356, 87, 198, 31221, 25]
+ assert tokenized_example['context'][:len(tokenized_input)].tolist() == tokenized_input
+ assert tokenized_example['context'][-1] == tokenizer.eos_token_id
+ assert len(tokenized_example['context']) == seqlen
+ assert type(tokenized_example['answer']) == str
+
+
+def test_qa_set_cot_no_cot(tmp_path):
+ pytest.importorskip('datasets')
+ local_data = os.path.join(os.path.dirname(__file__), 'local_data')
+ dataset_uri = f'{local_data}/triviaqa_small.jsonl'
+ transformers = pytest.importorskip('transformers')
+ tokenizer = transformers.AutoTokenizer.from_pretrained('facebook/opt-125m') # type: ignore reportUnboundVariable
+
+ tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
+ gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
+ dl = InContextLearningQATaskDataset(
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ max_seq_len=1024,
+ pad_tok_id=tokenizer.eos_token_id,
+ num_fewshot=0,
+ fewshot_random_seed=1234,
+ prompt_string='',
+ example_delimiter='\n',
+ continuation_delimiter=': ',
+ destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'),
+ )
+ assert not dl.has_cot
+
+
+def test_qa_set_cot_has_cot(tmp_path):
+ pytest.importorskip('datasets')
+ local_data = os.path.join(os.path.dirname(__file__), 'local_data')
+ dataset_uri = f'{local_data}/gsm8k_small.jsonl'
+ transformers = pytest.importorskip('transformers')
+ tokenizer = transformers.AutoTokenizer.from_pretrained('facebook/opt-125m') # type: ignore reportUnboundVariable
+
+ tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
+ gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
+ dl = InContextLearningQATaskDataset(
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ max_seq_len=1024,
+ pad_tok_id=tokenizer.eos_token_id,
+ num_fewshot=0,
+ fewshot_random_seed=1234,
+ prompt_string='',
+ example_delimiter='\n',
+ continuation_delimiter=': ',
+ destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'),
+ )
+ assert dl.has_cot
+
+
+def test_qa_get_max_answer_length(tiny_gpt2_tokenizer, tmp_path):
+ local_data = os.path.join(os.path.dirname(__file__), 'local_data')
+ dataset_uri = f'{local_data}/triviaqa_small.jsonl'
+ tokenizer = tiny_gpt2_tokenizer
+
+ tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
+ gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
+ dl = InContextLearningQATaskDataset(
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ max_seq_len=1024,
+ pad_tok_id=tokenizer.eos_token_id,
+ num_fewshot=0,
+ fewshot_random_seed=1234,
+ prompt_string='',
+ example_delimiter='',
+ continuation_delimiter='',
+ cot_delimiter='',
+ destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'),
+ )
+ # empirical number from the small test dataset
+ assert dl.max_answer_length == 7
+
+
+def test_qa_get_answer_from_example_with_no_cot(tmp_path, tiny_gpt2_tokenizer):
+ pytest.importorskip('datasets')
+ local_data = os.path.join(os.path.dirname(__file__), 'local_data')
+ dataset_uri = f'{local_data}/triviaqa_small.jsonl'
+
+ tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
+ gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
+ dl = InContextLearningQATaskDataset(
+ dataset_uri=dataset_uri,
+ tokenizer=tiny_gpt2_tokenizer,
+ max_seq_len=1024,
+ pad_tok_id=tiny_gpt2_tokenizer.eos_token_id,
+ num_fewshot=0,
+ fewshot_random_seed=1234,
+ prompt_string='',
+ example_delimiter='\n',
+ continuation_delimiter=': ',
+ cot_delimiter=' ### ',
+ destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'),
+ )
+ answer = dl.get_answer_from_example({
+ 'context': 'empty',
+ 'answer': 'this is the correct answer',
+ 'chain_of_thought': "Let's think step by step. "
+ })
+ assert answer == 'this is the correct answer'
+
+
+def test_qa_get_answer_from_example_with_cot(tmp_path, tiny_gpt2_tokenizer):
+ pytest.importorskip('datasets')
+ local_data = os.path.join(os.path.dirname(__file__), 'local_data')
+ dataset_uri = f'{local_data}/triviaqa_small.jsonl'
+
+ tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
+ gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
+ dl = InContextLearningQATaskDataset(
+ dataset_uri=dataset_uri,
+ tokenizer=tiny_gpt2_tokenizer,
+ max_seq_len=1024,
+ pad_tok_id=tiny_gpt2_tokenizer.eos_token_id,
+ num_fewshot=0,
+ fewshot_random_seed=1234,
+ prompt_string='',
+ example_delimiter='\n',
+ continuation_delimiter=': ',
+ cot_delimiter=' ### ',
+ destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'),
+ )
+ dl.has_cot = True
+ answer = dl.get_answer_from_example({
+ 'context': 'empty',
+ 'answer': 'this is the correct answer',
+ 'chain_of_thought': "Let's think step by step. "
+ })
+ assert answer == "Let's think step by step. ### this is the correct answer"
+
+
+def test_qa_tokenize_example(tiny_gpt2_tokenizer, tmp_path):
+ pytest.importorskip('datasets')
+ local_data = os.path.join(os.path.dirname(__file__), 'local_data')
+ dataset_uri = f'{local_data}/triviaqa_small.jsonl'
+
+ tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
+ gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
+ dl = InContextLearningQATaskDataset(
+ dataset_uri=dataset_uri,
+ tokenizer=tiny_gpt2_tokenizer,
+ max_seq_len=1024,
+ pad_tok_id=tiny_gpt2_tokenizer.eos_token_id,
+ num_fewshot=0,
+ fewshot_random_seed=1234,
+ prompt_string='',
+ example_delimiter='\n',
+ continuation_delimiter=': ',
+ cot_delimiter=' ### ',
+ destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'),
+ )
+ dl.has_cot = True
+ tokenized_example = dl.tokenize_example(
+ 'starting prompt', 'a context', {
+ 'context': 'empty',
+ 'answer': 'this is the correct answer',
+ 'aliases': ['this is the right answer', 'this is the best answer'],
+ 'chain_of_thought': "Let's think step by step. "
+ })
+ assert 'aliases' in tokenized_example
+ assert tokenized_example['aliases'] == ['this is the right answer', 'this is the best answer']
+
+
+def test_code_adjust_padding(tiny_gpt2_tokenizer, tmp_path):
+ local_data = os.path.join(os.path.dirname(__file__), 'local_data')
+ dataset_uri = f'{local_data}/human_eval_small.jsonl'
+ tokenizer = tiny_gpt2_tokenizer
+ seqlen = 2048
+ num_fewshot = 0
+ prompt_string = ''
+ gen_kwargs = {'temperature': .9, 'top_p': .95, 'num_beams': 9000}
+
+ dl = InContextLearningCodeEvalDataset(
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ max_seq_len=seqlen,
+ pad_tok_id=tokenizer.eos_token_id,
+ num_fewshot=num_fewshot,
+ fewshot_random_seed=1,
+ prompt_string=prompt_string,
+ example_delimiter='\n',
+ prelimiter='Code start:',
+ continuation_delimiter='\nPlease code:',
+ destination_path=str(tmp_path / 'test_human_eval_small.jsonl'),
+ generation_kwargs=gen_kwargs,
+ generations_per_sample=10,
+ )
+
+ assert all(len(data['prompt']) == 148 for data in dl.dataset) # pyright: ignore [reportGeneralTypeIssues]
+
+
+def test_code_update_gen_kwargs(tiny_gpt2_tokenizer, tmp_path):
+ local_data = os.path.join(os.path.dirname(__file__), 'local_data')
+ dataset_uri = f'{local_data}/human_eval_small.jsonl'
+ tokenizer = tiny_gpt2_tokenizer
+ seqlen = 2048
+ num_fewshot = 0
+ prompt_string = ''
+ gen_kwargs = {'temperature': .9, 'top_p': .95, 'num_beams': 9000}
+
+ dl = InContextLearningCodeEvalDataset(
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ max_seq_len=seqlen,
+ pad_tok_id=tokenizer.eos_token_id,
+ num_fewshot=num_fewshot,
+ fewshot_random_seed=1,
+ prompt_string=prompt_string,
+ example_delimiter='\n',
+ prelimiter='Code start:',
+ continuation_delimiter='\nPlease code:',
+ destination_path=str(tmp_path / 'test_human_eval_small.jsonl'),
+ generation_kwargs=gen_kwargs,
+ generations_per_sample=10,
+ )
+ assert dl.base_batch['generation_kwargs']['num_beams'] == 9000
+ assert dl.base_batch['generation_kwargs']['top_p'] == .95
+ assert dl.base_batch['generation_kwargs']['temperature'] == .9
+ assert dl.base_batch['generation_kwargs']['do_sample'] == True
+
+
+def test_mc_tokenize_example(tiny_gpt2_tokenizer, tmp_path):
+ local_data = os.path.join(os.path.dirname(__file__), 'local_data')
+ dataset_uri = f'{local_data}/mmlu_small.jsonl'
+ tokenizer = tiny_gpt2_tokenizer
+ seqlen = 2048
+ num_fewshot = 0
+ prompt_string = ''
+ seqlen = 2048
+ dl = InContextLearningMultipleChoiceTaskDataset(
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ max_seq_len=seqlen,
+ pad_tok_id=tokenizer.eos_token_id,
+ num_fewshot=num_fewshot,
+ fewshot_random_seed=1,
+ prompt_string=prompt_string,
+ example_delimiter='\n',
+ continuation_delimiter=' ### ',
+ destination_path=str(tmp_path / 'test_human_eval_small.jsonl'),
+ )
+ example = {
+ 'context': "Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: ",
+ 'choices': ['A', 'B', 'C', 'D'],
+ 'gold': 2
+ }
+ tokenized_example = dl.tokenize_example(prompt_and_fewshot='Answer the following: ',
+ ctxt=example['context'],
+ example=example)
+ unpadded_queries = [context[context != tokenizer.eos_token_id] for context in tokenized_example['query']]
+ untokenized_inputs = [tokenizer.decode(unpadded_input) for unpadded_input in unpadded_queries]
+ correct_output = [
+ "Answer the following: Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: A",
+ "Answer the following: Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: B",
+ "Answer the following: Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: C",
+ "Answer the following: Who's the best eval researcher?\n A. Jeremy\n B. Tessa\n C. Max\n D. Other\nAnswer: D"
+ ]
+ assert untokenized_inputs == correct_output
+
+
+def test_schema_construct_context(tiny_gpt2_tokenizer, tmp_path):
+ local_data = os.path.join(os.path.dirname(__file__), 'local_data')
+ dataset_uri = f'{local_data}/winograd_small.jsonl'
+ tokenizer = tiny_gpt2_tokenizer
+ seqlen = 2048
+ num_fewshot = 0
+ seqlen = 2048
+ dl = InContextLearningSchemaTaskDataset(
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ max_seq_len=seqlen,
+ pad_tok_id=tokenizer.eos_token_id,
+ num_fewshot=num_fewshot,
+ fewshot_random_seed=1,
+ prompt_string='',
+ example_delimiter='\n',
+ continuation_delimiter=' ### ',
+ destination_path=str(tmp_path / 'test_human_eval_small.jsonl'),
+ )
+ example = {'context_options': ['cont one', 'cont two'], 'gold': 0, 'continuation': 'this is a continuation'}
+ constructed_context = dl.construct_context(example)
+ assert constructed_context == 'cont one ### this is a continuation'
+ constructed_context = dl.construct_context(example, preceding_text='text')
+ assert constructed_context == '\ncont one ### this is a continuation'
+
+
+def test_schema_construct_multiple_contexts(tiny_gpt2_tokenizer, tmp_path):
+ local_data = os.path.join(os.path.dirname(__file__), 'local_data')
+ dataset_uri = f'{local_data}/winograd_small.jsonl'
+ tokenizer = tiny_gpt2_tokenizer
+ seqlen = 2048
+ num_fewshot = 0
+ prompt_string = ''
+ seqlen = 2048
+ dl = InContextLearningSchemaTaskDataset(
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ max_seq_len=seqlen,
+ pad_tok_id=tokenizer.eos_token_id,
+ num_fewshot=num_fewshot,
+ fewshot_random_seed=1,
+ prompt_string=prompt_string,
+ example_delimiter='\n',
+ continuation_delimiter=' ### ',
+ destination_path=str(tmp_path / 'test_human_eval_small.jsonl'),
+ )
+ example = {'context_options': ['cont one', 'cont two'], 'gold': 0, 'continuation': 'this is a continuation'}
+ constructed_contexts = dl._construct_multiple_contexts(example)
+ assert constructed_contexts == ['cont one', 'cont two']
+ constructed_contexts = dl._construct_multiple_contexts(example, preceding_text='some text')
+ assert constructed_contexts == ['\ncont one ###', '\ncont two ###']
+
+
+def test_schema_tokenize_example(tiny_gpt2_tokenizer, tmp_path):
+ local_data = os.path.join(os.path.dirname(__file__), 'local_data')
+ dataset_uri = f'{local_data}/winograd_small.jsonl'
+ tokenizer = tiny_gpt2_tokenizer
+ seqlen = 2048
+ num_fewshot = 0
+ prompt_string = ''
+ seqlen = 2048
+ dl = InContextLearningSchemaTaskDataset(
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ max_seq_len=seqlen,
+ pad_tok_id=tokenizer.eos_token_id,
+ num_fewshot=num_fewshot,
+ fewshot_random_seed=1,
+ prompt_string=prompt_string,
+ example_delimiter='\n',
+ continuation_delimiter=' ### ',
+ destination_path=str(tmp_path / 'test_human_eval_small.jsonl'),
+ )
+ example = {'context_options': ['context one', 'context two'], 'gold': 0, 'continuation': 'this is a continuation'}
+ tokenized_example = dl.tokenize_example(prompt_and_fewshot='prompt ',
+ context_options=example['context_options'],
+ example=example)
+ assert all(tiny_gpt2_tokenizer.decode(cont) == ' this is a continuation' for cont in tokenized_example['answer'])
+ unpadded_inputs = [context[context != tokenizer.eos_token_id] for context in tokenized_example['context_options']]
+ untokenized_inputs = [tokenizer.decode(unpadded_input) for unpadded_input in unpadded_inputs]
+ assert untokenized_inputs == [
+ 'prompt context one this is a continuation', 'prompt context two this is a continuation'
+ ]
@pytest.mark.parametrize('dataset_uri', ['mmlu_small.jsonl'])
@@ -103,9 +856,9 @@ def test_mc_task_dataloader_subcategories(dataset_uri, tiny_gpt2_tokenizer, tmp_
batch_size = 8
seqlen = 64
dls = get_icl_task_dataloader('multiple_choice',
- dataset_uri,
- tokenizer,
- batch_size,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
max_seq_len=seqlen,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=2,
@@ -147,9 +900,9 @@ def test_lm_task_dataloader_extra_space(dataset_uri, tiny_gpt2_tokenizer, tmp_pa
batch_size = 2
seqlen = 64
dl = get_icl_task_dataloader('language_modeling',
- dataset_uri,
- tokenizer,
- batch_size,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
max_seq_len=seqlen,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=10,
@@ -188,9 +941,9 @@ def test_lm_task_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path):
batch_size = 2
seqlen = 64
dl = get_icl_task_dataloader('language_modeling',
- dataset_uri,
- tokenizer,
- batch_size,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
max_seq_len=seqlen,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=0,
@@ -226,9 +979,9 @@ def test_schema_task_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path):
batch_size = 2
seqlen = 64
dl = get_icl_task_dataloader('schema',
- dataset_uri,
- tokenizer,
- batch_size,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
max_seq_len=seqlen,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=1,
@@ -261,19 +1014,18 @@ def test_schema_task_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path):
@pytest.mark.parametrize('dataset_uri', ['winograd_small.jsonl'])
-def test_schema_task_dataloader_sentpiece_tokenizer(dataset_uri, tmp_path):
+def test_schema_task_dataloader_sentpiece_tokenizer(dataset_uri, tmp_path, tiny_llama_tokenizer):
pytest.importorskip('datasets')
local_data = os.path.join(os.path.dirname(__file__), 'local_data')
-
- tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-7b', use_fast=False)
+ tokenizer = tiny_llama_tokenizer
dataset_uri = f'{local_data}/{dataset_uri}'
batch_size = 2
seqlen = 64
dl = get_icl_task_dataloader('schema',
- dataset_uri,
- tokenizer,
- batch_size,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
max_seq_len=seqlen,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=1,
@@ -318,9 +1070,9 @@ def test_lm_task_dataloader_opt_tokenizer(tiny_opt_tokenizer, dataset_uri, num_f
batch_size = 2
seqlen = 512
dl = get_icl_task_dataloader('language_modeling',
- dataset_uri,
- tokenizer,
- batch_size,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
max_seq_len=seqlen,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=num_fewshot,
@@ -360,9 +1112,9 @@ def test_mc_task_dataloader_opt_tokenizer(tiny_opt_tokenizer, dataset_uri, num_f
batch_size = 4
seqlen = 64
dl = get_icl_task_dataloader('multiple_choice',
- dataset_uri,
- tokenizer,
- batch_size,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
max_seq_len=seqlen,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=num_fewshot,
@@ -410,9 +1162,9 @@ def test_mc_split_batch(tiny_opt_tokenizer, dataset_uri, num_fewshot, tmp_path):
batch_size = 4
seqlen = 512
dl = get_icl_task_dataloader('multiple_choice',
- dataset_uri,
- tokenizer,
- batch_size,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
max_seq_len=seqlen,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=num_fewshot,
@@ -466,13 +1218,13 @@ def test_qa_split_batch(tiny_opt_tokenizer, dataset_uri, tmp_path):
tokenizer = tiny_opt_tokenizer
tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
- gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
+ gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) # for dist
dl = get_icl_task_dataloader(
- 'question_answering',
- dataset_uri,
- tokenizer,
- 8,
- max_seq_len=64,
+ icl_task_type='question_answering',
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=8,
+ max_seq_len=1024,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=0,
prompt_string='',
@@ -553,9 +1305,9 @@ def test_qa_task_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path, num_fews
# empirical number from the small test dataset
maximum_answer_length = 7
dl = get_icl_task_dataloader('question_answering',
- dataset_uri,
- tokenizer,
- batch_size,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
max_seq_len=seqlen,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=num_fewshot,
@@ -605,9 +1357,9 @@ def test_qa_task_with_cot_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path,
# empirical number from the small test dataset
maximum_answer_length = 132
dl = get_icl_task_dataloader('question_answering',
- dataset_uri,
- tokenizer,
- batch_size,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
max_seq_len=seqlen,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=num_fewshot,
@@ -658,9 +1410,9 @@ def test_mc_task_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path):
batch_size = 2
seqlen = 64
dl = get_icl_task_dataloader('multiple_choice',
- dataset_uri,
- tokenizer,
- batch_size,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
max_seq_len=seqlen,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=1,
@@ -697,16 +1449,18 @@ def test_code_eval_split_batch(dataset_uri, tmp_path):
pytest.importorskip('datasets')
local_data = os.path.join(os.path.dirname(__file__), 'local_data')
dataset_uri = f'{local_data}/{dataset_uri}'
- tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
+ transformers = pytest.importorskip('transformers')
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ 'EleutherAI/gpt-neox-20b') # type: ignore reportUnboundVariable
tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
dl = get_icl_task_dataloader(
'code_evaluation',
- dataset_uri,
- tokenizer,
- 8,
- max_seq_len=64,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=8,
+ max_seq_len=1024,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=2,
prompt_string='',
@@ -738,7 +1492,6 @@ def test_code_eval_split_batch(dataset_uri, tmp_path):
'labels': str,
'prompts': str,
'tests': str,
- 'canonical_solutions': str,
'entry_points': str,
'test_inputs': list,
'test_outputs': list,
@@ -763,25 +1516,27 @@ def test_code_eval_split_batch(dataset_uri, tmp_path):
@pytest.mark.parametrize('num_fewshot', [0, 2])
@pytest.mark.parametrize('prompt_string', ['Please code:\n', ''])
@pytest.mark.parametrize('generations_per_sample', [1, 3])
-def test_code_eval_sentpiece_dataloader(dataset_uri, tmp_path, num_fewshot, prompt_string, generations_per_sample):
+def test_code_eval_sentpiece_dataloader(dataset_uri, tmp_path, num_fewshot, prompt_string, generations_per_sample,
+ tiny_llama_tokenizer):
pytest.importorskip('datasets')
local_data = os.path.join(os.path.dirname(__file__), 'local_data')
- tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-7b')
+ tokenizer = tiny_llama_tokenizer
dataset_uri = f'{local_data}/{dataset_uri}'
batch_size = 4
seqlen = 2048
dl = get_icl_task_dataloader('code_evaluation',
- dataset_uri,
- tokenizer,
- batch_size,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
max_seq_len=seqlen,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=num_fewshot,
prompt_string=prompt_string,
example_delimiter='\n',
+ continuation_delimiter='',
question_prelimiter='Code start: \n',
destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'),
generations_per_sample=generations_per_sample)
@@ -828,25 +1583,26 @@ def test_code_eval_sentpiece_dataloader(dataset_uri, tmp_path, num_fewshot, prom
@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl'])
-def test_code_eval_test_cases(dataset_uri, tmp_path):
+def test_code_eval_test_cases(dataset_uri, tmp_path, tiny_llama_tokenizer):
pytest.importorskip('datasets')
local_data = os.path.join(os.path.dirname(__file__), 'local_data')
- tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-7b')
+ tokenizer = tiny_llama_tokenizer
dataset_uri = f'{local_data}/{dataset_uri}'
batch_size = 4
seqlen = 512
dl = get_icl_task_dataloader('code_evaluation',
- dataset_uri,
- tokenizer,
- batch_size,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
max_seq_len=seqlen,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=0,
prompt_string='',
example_delimiter='\n',
+ continuation_delimiter='',
question_prelimiter='Code start: \n',
destination_path=str(tmp_path / f'icl_.jsonl'),
generations_per_sample=1)
@@ -866,9 +1622,8 @@ def test_code_eval_test_cases(dataset_uri, tmp_path):
assert any(item[0] != tokenizer.eos_token_id for item in batch['input_ids']) # longest should be pushed left
mod = types.ModuleType('test_module')
- for prompt, solution, inputs, outputs, entry_point in zip(batch['prompts'], batch['canonical_solutions'],
- batch['test_inputs'], batch['test_outputs'],
- batch['entry_points']):
+ for prompt, solution, inputs, outputs, entry_point in zip(batch['prompts'], batch['labels'], batch['test_inputs'],
+ batch['test_outputs'], batch['entry_points']):
exec(prompt + solution, mod.__dict__)
for test_input, test_output in zip(inputs, outputs):
result = mod.__dict__[entry_point](*eval(test_input))
@@ -876,26 +1631,27 @@ def test_code_eval_test_cases(dataset_uri, tmp_path):
@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl'])
-def test_code_eval_pass_at_k_validity(dataset_uri, tmp_path):
+def test_code_eval_pass_at_k_validity(dataset_uri, tmp_path, tiny_llama_tokenizer):
pytest.importorskip('datasets')
local_data = os.path.join(os.path.dirname(__file__), 'local_data')
- tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-7b')
+ tokenizer = tiny_llama_tokenizer
dataset_uri = f'{local_data}/{dataset_uri}'
batch_size = 2
seqlen = 64
with pytest.raises(ValueError, match=r'.* pass_at_k .*'):
get_icl_task_dataloader('code_evaluation',
- dataset_uri,
- tokenizer,
- batch_size,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
max_seq_len=seqlen,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=0,
prompt_string='',
example_delimiter='\n',
+ continuation_delimiter='',
question_prelimiter='Code start: \n',
destination_path=str(tmp_path / f'icl_.jsonl'),
pass_at_k=10,
@@ -911,23 +1667,29 @@ def test_code_eval_task_dataloader(dataset_uri, tmp_path, num_fewshot, prompt_st
local_data = os.path.join(os.path.dirname(__file__), 'local_data')
- tokenizer = AutoTokenizer.from_pretrained('mosaicml/mpt-7b')
+ transformers = pytest.importorskip('transformers')
+ tokenizer = transformers.AutoTokenizer.from_pretrained('mosaicml/mpt-7b') # type: ignore reportUnboundVariable
dataset_uri = f'{local_data}/{dataset_uri}'
batch_size = 4
seqlen = 2048
dl = get_icl_task_dataloader('code_evaluation',
- dataset_uri,
- tokenizer,
- batch_size,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
max_seq_len=seqlen,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=num_fewshot,
prompt_string=prompt_string,
example_delimiter='\n',
+ continuation_delimiter='',
question_prelimiter='Code start: \n',
destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'),
- generations_per_sample=generations_per_sample)
+ generations_per_sample=generations_per_sample,
+ generation_kwargs={
+ 'temperature': .9,
+ 'top_k': 40
+ })
assert isinstance(dl, DataSpec)
assert isinstance(dl.dataloader, DataLoader) # pyright
@@ -970,6 +1732,59 @@ def test_code_eval_task_dataloader(dataset_uri, tmp_path, num_fewshot, prompt_st
)
+@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl'])
+@pytest.mark.parametrize('num_fewshot', [0, 1])
+def test_eval_split_batch(tiny_opt_tokenizer, dataset_uri, num_fewshot, tmp_path):
+ pytest.importorskip('datasets')
+
+ local_data = os.path.join(os.path.dirname(__file__), 'local_data')
+ transformers = pytest.importorskip('transformers')
+ tokenizer = transformers.AutoTokenizer.from_pretrained('mosaicml/mpt-7b') # type: ignore reportUnboundVariable
+ dataset_uri = f'{local_data}/{dataset_uri}'
+ batch_size = 4
+ seqlen = 512
+
+ dl = get_icl_task_dataloader('code_evaluation',
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
+ max_seq_len=seqlen,
+ pad_tok_id=tokenizer.eos_token_id,
+ num_fewshot=num_fewshot,
+ prompt_string='',
+ example_delimiter='\n',
+ continuation_delimiter='',
+ question_prelimiter='Code start: \n',
+ destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'),
+ generations_per_sample=1,
+ generation_kwargs={
+ 'temperature': .9,
+ 'top_k': 40
+ })
+ assert isinstance(dl, DataSpec)
+ assert isinstance(dl.dataloader, DataLoader) # pyright
+ batch = next(dl.dataloader._get_iterator())
+ microbatch_size = 1
+ microbatches = dl.split_batch(batch, microbatch_size)
+ assert len(microbatches) == 4
+ for microbatch in microbatches:
+ assert dl.get_num_samples_in_batch(microbatch) == 1
+ assert 'input_ids' in microbatch
+ # TODO: what should this be?
+ # assert tuple(microbatch['input_ids'].shape) == (microbatch_size, seqlen)
+ assert 'attention_mask' in microbatch
+ # assert tuple(microbatch['attention_mask'].shape) == (microbatch_size, seqlen)
+ assert isinstance(microbatch['generation_kwargs'], dict)
+ assert microbatch['generation_kwargs']['temperature'] == .9
+ assert microbatch['generation_kwargs']['top_k'] == 40
+ assert microbatch['generation_kwargs']['pad_token_id'] == 0
+ assert microbatch['generation_kwargs']['num_beams'] == 1
+ assert microbatch['generation_kwargs']['num_return_sequences'] == 1
+ assert microbatch['generation_kwargs']['do_sample'] == True
+ assert microbatch['generation_kwargs']['use_cache'] == True
+ assert microbatch['generation_kwargs']['eos_token_id'] == 0
+
+
@pytest.mark.parametrize('dataset_uri', ['lambada_small.jsonl'])
@pytest.mark.parametrize('num_fewshot', [0, 5])
@device('gpu')
@@ -979,11 +1794,12 @@ def test_lm_task_evaluation(device, dataset_uri, num_fewshot, tiny_gpt2_tokenize
local_data = os.path.join(os.path.dirname(__file__), 'local_data')
dataset_uri = f'{local_data}/{dataset_uri}'
tokenizer = tiny_gpt2_tokenizer
+ batch_size = 2
dl = get_icl_task_dataloader(
'language_modeling',
- dataset_uri,
- tokenizer,
- 2,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
max_seq_len=2048,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=num_fewshot,
@@ -995,6 +1811,7 @@ def test_lm_task_evaluation(device, dataset_uri, num_fewshot, tiny_gpt2_tokenize
evaluator = Evaluator(label='lambada', dataloader=dl, metric_names=['InContextLearningLMAccuracy'])
+ transformers = pytest.importorskip('transformers')
config = transformers.AutoConfig.from_pretrained('EleutherAI/gpt-neo-125M')
model = transformers.AutoModelForCausalLM.from_config(config)
model = HuggingFaceModel(
@@ -1010,8 +1827,8 @@ def test_lm_task_evaluation(device, dataset_uri, num_fewshot, tiny_gpt2_tokenize
assert in_memory_logger.data['metrics/lambada/InContextLearningLMAccuracy'][0][1].item() == 0
-@pytest.mark.parametrize('dataset_uri', ['winograd_small.jsonl'])
@pytest.mark.parametrize('num_fewshot', [0, 5])
+@pytest.mark.parametrize('dataset_uri', ['winograd_small.jsonl'])
@pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning')
def test_schema_task_evaluation(num_fewshot, dataset_uri, tiny_gpt2_tokenizer, tmp_path, tiny_gpt2_model):
pytest.importorskip('datasets')
@@ -1019,12 +1836,13 @@ def test_schema_task_evaluation(num_fewshot, dataset_uri, tiny_gpt2_tokenizer, t
local_data = os.path.join(os.path.dirname(__file__), 'local_data')
dataset_uri = f'{local_data}/{dataset_uri}'
tokenizer = tiny_gpt2_tokenizer
+ batch_size = 8
dl = get_icl_task_dataloader(
'schema',
- dataset_uri,
- tokenizer,
- 8,
- max_seq_len=64,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
+ max_seq_len=1024,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=num_fewshot,
prompt_string='',
@@ -1065,13 +1883,16 @@ def test_mc_task_evaluation_subcategories(device, world_size, dataset_uri, num_f
local_data = os.path.join(os.path.dirname(__file__), 'local_data')
dataset_uri = f'{local_data}/{dataset_uri}'
tokenizer = tiny_gpt2_tokenizer
+ batch_size = 8
+ max_seq_len = 64
tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
+ reproducibility.seed_all(1234)
dls = get_icl_task_dataloader('multiple_choice',
- dataset_uri,
- tokenizer,
- 8,
- max_seq_len=64,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
+ max_seq_len=max_seq_len,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=num_fewshot,
prompt_string='',
@@ -1104,29 +1925,35 @@ def test_mc_task_evaluation_subcategories(device, world_size, dataset_uri, num_f
@pytest.mark.parametrize('dataset_uri', ['piqa_small.jsonl', 'hellaswag_small.jsonl'])
-@device('gpu')
@pytest.mark.parametrize('num_fewshot', [0, 5])
-def test_mc_task_evaluation(device, num_fewshot, dataset_uri, tiny_gpt2_tokenizer, tmp_path, tiny_gpt2_model):
+@pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning')
+@device('gpu')
+@world_size(1, 2)
+def test_mc_task_evaluation(device, world_size, num_fewshot, dataset_uri, tiny_gpt2_tokenizer, tmp_path,
+ tiny_gpt2_model):
pytest.importorskip('datasets')
in_memory_logger = InMemoryLogger() # track the logged metrics in the in_memory_logger
local_data = os.path.join(os.path.dirname(__file__), 'local_data')
dataset_uri = f'{local_data}/{dataset_uri}'
tokenizer = tiny_gpt2_tokenizer
+ batch_size = 8
+ tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
+ gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
# seed because the fewshot selection is currently unseeded
reproducibility.seed_all(1234)
dl = get_icl_task_dataloader(
'multiple_choice',
- dataset_uri,
- tokenizer,
- 8,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
max_seq_len=64,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=num_fewshot,
prompt_string='',
example_delimiter='\n',
continuation_delimiter=': ',
- destination_path=str(tmp_path / 'icl.jsonl'),
+ destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'),
)
evaluator = Evaluator(label='mc', dataloader=dl, metric_names=['InContextLearningMultipleChoiceAccuracy'])
@@ -1146,14 +1973,17 @@ def test_mc_task_evaluation(device, num_fewshot, dataset_uri, tiny_gpt2_tokenize
with open(dataset_uri) as f:
for _ in f:
num_samples += 1
- assert trainer.state.eval_metrics['mc']['InContextLearningMultipleChoiceAccuracy'].total == num_samples
+ total = trainer.state.eval_metrics['mc']['InContextLearningMultipleChoiceAccuracy'].total
+ dist.all_reduce(total) # type: ignore
+ assert total.item() == num_samples # type: ignore
+@pytest.mark.parametrize('num_fewshot', [0, 5])
@pytest.mark.parametrize('dataset_uri', ['triviaqa_small.jsonl'])
+@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
+@pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning')
@device('gpu')
@world_size(1, 2)
-@pytest.mark.parametrize('num_fewshot', [0, 5])
-@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
def test_qa_task_evaluation_opt_tokenizer(device, world_size, tiny_opt_tokenizer, tiny_opt_model, num_fewshot,
dataset_uri, tmp_path):
pytest.importorskip('datasets')
@@ -1162,14 +1992,15 @@ def test_qa_task_evaluation_opt_tokenizer(device, world_size, tiny_opt_tokenizer
dataset_uri = f'{local_data}/{dataset_uri}'
tokenizer = tiny_opt_tokenizer
+ batch_size = 4
tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
dl = get_icl_task_dataloader(
'question_answering',
- dataset_uri,
- tokenizer,
- 2,
- max_seq_len=64,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
+ max_seq_len=1024,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=num_fewshot,
prompt_string='',
@@ -1193,11 +2024,12 @@ def test_qa_task_evaluation_opt_tokenizer(device, world_size, tiny_opt_tokenizer
assert in_memory_logger.data['metrics/triviaqa/InContextLearningQAAccuracy'][0][1].item() == 0
+@pytest.mark.parametrize('num_fewshot', [5])
@pytest.mark.parametrize('dataset_uri', ['gsm8k_small.jsonl'])
@device('gpu')
@world_size(1, 2)
-@pytest.mark.parametrize('num_fewshot', [5])
@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
+@pytest.mark.filterwarnings(r'ignore:Cannot split .* of length.*:UserWarning')
def test_qa_task_evaluation_with_cot_opt_tokenizer(device, world_size, tiny_opt_tokenizer, tiny_opt_model, num_fewshot,
dataset_uri, tmp_path):
pytest.importorskip('datasets')
@@ -1206,14 +2038,15 @@ def test_qa_task_evaluation_with_cot_opt_tokenizer(device, world_size, tiny_opt_
dataset_uri = f'{local_data}/{dataset_uri}'
tokenizer = tiny_opt_tokenizer
+ batch_size = 4
tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
dl = get_icl_task_dataloader(
'question_answering',
- dataset_uri,
- tokenizer,
- 2,
- max_seq_len=256,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
+ max_seq_len=1024,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=num_fewshot,
prompt_string='',
@@ -1239,9 +2072,9 @@ def test_qa_task_evaluation_with_cot_opt_tokenizer(device, world_size, tiny_opt_
@pytest.mark.parametrize('dataset_uri', ['triviaqa_small.jsonl'])
+@pytest.mark.parametrize('num_fewshot', [0, 5])
@device('gpu')
@world_size(1, 2)
-@pytest.mark.parametrize('num_fewshot', [0, 5])
@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
def test_qa_task_evaluation(device, world_size, num_fewshot, dataset_uri, tiny_gpt2_tokenizer, tiny_gpt2_model,
tmp_path):
@@ -1250,14 +2083,15 @@ def test_qa_task_evaluation(device, world_size, num_fewshot, dataset_uri, tiny_g
local_data = os.path.join(os.path.dirname(__file__), 'local_data')
dataset_uri = f'{local_data}/{dataset_uri}'
tokenizer = tiny_gpt2_tokenizer
+ batch_size = 2
tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
dl = get_icl_task_dataloader(
'question_answering',
- dataset_uri,
- tokenizer,
- 2,
- max_seq_len=64,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
+ max_seq_len=1024,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=num_fewshot,
prompt_string='',
@@ -1283,10 +2117,10 @@ def test_qa_task_evaluation(device, world_size, num_fewshot, dataset_uri, tiny_g
@pytest.mark.parametrize('dataset_uri', ['gsm8k_small.jsonl'])
-@device('gpu')
-@world_size(1, 2)
@pytest.mark.parametrize('num_fewshot', [5])
@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
+@device('gpu')
+@world_size(1, 2)
def test_qa_task_with_cot_evaluation(device, world_size, num_fewshot, dataset_uri, tiny_gpt2_tokenizer, tiny_gpt2_model,
tmp_path):
pytest.importorskip('datasets')
@@ -1294,14 +2128,15 @@ def test_qa_task_with_cot_evaluation(device, world_size, num_fewshot, dataset_ur
local_data = os.path.join(os.path.dirname(__file__), 'local_data')
dataset_uri = f'{local_data}/{dataset_uri}'
tokenizer = tiny_gpt2_tokenizer
+ batch_size = 2
tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
dl = get_icl_task_dataloader(
'question_answering',
- dataset_uri,
- tokenizer,
- 2,
- max_seq_len=256,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
+ max_seq_len=1024,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=num_fewshot,
prompt_string='',
@@ -1340,10 +2175,10 @@ def test_code_eval_requires_valid_envvar(monkeypatch):
@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl'])
+@pytest.mark.parametrize('num_fewshot', [0])
+@pytest.mark.parametrize('generations_per_sample', range(1, 3))
@device('gpu')
@world_size(1, 2)
-@pytest.mark.parametrize('num_fewshot', [0])
-@pytest.mark.parametrize('generations_per_sample', [1, 2])
@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
def test_code_eval_microbatching(monkeypatch, device, world_size, tiny_opt_tokenizer, tiny_opt_model, num_fewshot,
dataset_uri, tmp_path, generations_per_sample):
@@ -1353,15 +2188,16 @@ def test_code_eval_microbatching(monkeypatch, device, world_size, tiny_opt_token
local_data = os.path.join(os.path.dirname(__file__), 'local_data')
dataset_uri = f'{local_data}/{dataset_uri}'
tokenizer = tiny_opt_tokenizer
+ batch_size = 4
tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
dl = get_icl_task_dataloader(
'code_evaluation',
- dataset_uri,
- tokenizer,
- 2,
- max_seq_len=256,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
+ max_seq_len=150,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=num_fewshot,
prompt_string='',
@@ -1391,10 +2227,10 @@ def test_code_eval_microbatching(monkeypatch, device, world_size, tiny_opt_token
@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl'])
+@pytest.mark.parametrize('num_fewshot', [0])
+@pytest.mark.parametrize('generations_per_sample', range(1, 3))
@device('gpu')
@world_size(1, 2)
-@pytest.mark.parametrize('num_fewshot', [0])
-@pytest.mark.parametrize('generations_per_sample', [1, 2])
@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
def test_code_eval_sentpiece_evaluation(monkeypatch, device, world_size, num_fewshot, dataset_uri, tiny_t5_tokenizer,
tiny_t5_model, tmp_path, generations_per_sample):
@@ -1404,14 +2240,15 @@ def test_code_eval_sentpiece_evaluation(monkeypatch, device, world_size, num_few
local_data = os.path.join(os.path.dirname(__file__), 'local_data')
dataset_uri = f'{local_data}/{dataset_uri}'
tokenizer = tiny_t5_tokenizer
+ batch_size = 2
tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
dl = get_icl_task_dataloader(
'code_evaluation',
- dataset_uri,
- tokenizer,
- 2,
- max_seq_len=256,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
+ max_seq_len=175,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=num_fewshot,
prompt_string='',
@@ -1438,11 +2275,11 @@ def test_code_eval_sentpiece_evaluation(monkeypatch, device, world_size, num_few
@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl'])
-@device('gpu')
-@world_size(1, 2)
@pytest.mark.parametrize('num_fewshot', [0, 2])
@pytest.mark.parametrize('generations_per_sample', [1])
@pytest.mark.filterwarnings(r'ignore: Input length of input_ids is')
+@device('gpu')
+@world_size(1, 2)
@pytest.mark.filterwarnings(r'ignore:.*The dataloader_len \(2\) is greater than the length.*:UserWarning')
def test_code_eval_task_evaluation(monkeypatch, device, world_size, num_fewshot, dataset_uri, tiny_gpt2_tokenizer,
tiny_gpt2_model, tmp_path, generations_per_sample):
@@ -1452,13 +2289,14 @@ def test_code_eval_task_evaluation(monkeypatch, device, world_size, num_fewshot,
local_data = os.path.join(os.path.dirname(__file__), 'local_data')
dataset_uri = f'{local_data}/{dataset_uri}'
tokenizer = tiny_gpt2_tokenizer
+ batch_size = 2
tmp_path_to_broadcast = str(os.path.abspath(tmp_path))
gathered_paths = dist.all_gather_object(tmp_path_to_broadcast)
dl = get_icl_task_dataloader(
'code_evaluation',
- dataset_uri,
- tokenizer,
- 2,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
max_seq_len=64 * num_fewshot,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=num_fewshot,
@@ -1496,9 +2334,9 @@ def test_lm_spacing_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path):
batch_size = 2
seqlen = 512
dl = get_icl_task_dataloader('language_modeling',
- dataset_uri,
- tokenizer,
- batch_size,
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
max_seq_len=seqlen,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=1,
@@ -1522,3 +2360,112 @@ def test_lm_spacing_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path):
assert first_batch_without_last_word.count(' UNIQUE ') == 1
assert second_batch_without_last_word.count(' UNIQUE ') == 1
+
+
+@pytest.mark.parametrize('dataset_uri', ['hf://mosaicml/test_dataset'])
+@pytest.mark.parametrize('num_fewshot', [0, 1])
+@pytest.mark.parametrize('prompt_string', ['Complete the voiceline: ', ''])
+@pytest.mark.parametrize('hf_loading_vars', [{
+ 'split': 'test',
+ 'name': 'juggernaut',
+}])
+@pytest.mark.parametrize('hf_parsing_map', [None, {'context': ['context'], 'continuation': ['continuation']}])
+@pytest.mark.filterwarnings(
+ r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning')
+def test_hf_dataloading_lm_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path, num_fewshot, prompt_string,
+ hf_loading_vars, hf_parsing_map):
+ pytest.importorskip('datasets')
+
+ tokenizer = tiny_gpt2_tokenizer
+ batch_size = 2
+ seqlen = 2048
+ dl = get_icl_task_dataloader('language_modeling',
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
+ max_seq_len=seqlen,
+ pad_tok_id=tokenizer.eos_token_id,
+ num_fewshot=0,
+ prompt_string='',
+ example_delimiter='\n',
+ continuation_delimiter=' ',
+ destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'),
+ hf_loading_vars=hf_loading_vars,
+ hf_parsing_map=hf_parsing_map)
+ assert isinstance(dl, DataSpec)
+ assert isinstance(dl.dataloader, DataLoader) # pyright
+ batch = next(dl.dataloader._get_iterator())
+
+ assert 'input_ids' in batch
+ assert tuple(batch['input_ids'].shape) == (batch_size, seqlen)
+ assert 'attention_mask' in batch
+ assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen)
+ assert 'continuation_indices' in batch
+ assert isinstance(batch['continuation_indices'], list) and len(batch['continuation_indices']) == batch_size
+ assert 'mode' in batch
+ assert batch['mode'] == 'icl_task'
+ min_idx = min(batch['continuation_indices'][0]).item()
+ max_idx = max(batch['continuation_indices'][0]).item()
+ assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + 1]) == ' and me.'
+
+ decoded_batch = [tokenizer.decode(row[row != tokenizer.eos_token_id]) for row in batch['input_ids']]
+ assert decoded_batch[0] == "Looks like it's just you and me."
+ assert decoded_batch[1] == "There's a fine line between bravery and stupidity."
+
+
+@pytest.mark.parametrize('dataset_uri', ['hf://mosaicml/test_dataset'])
+@pytest.mark.parametrize('num_fewshot', [0, 1])
+@pytest.mark.parametrize('prompt_string', ['What spell does this invoke? ', ''])
+@pytest.mark.parametrize('hf_loading_vars', [{
+ 'split': 'test',
+ 'name': 'invoker',
+}])
+@pytest.mark.parametrize('hf_parsing_map', [{'context': ['quas', 'wex', 'exort'], 'answer': ['spell']}])
+@pytest.mark.filterwarnings(
+ r'ignore:The repository for mosaicml/test_dataset contains custom code which must*:FutureWarning')
+def test_hf_dataloading_custom_parsing(dataset_uri, tiny_gpt2_tokenizer, tmp_path, num_fewshot, prompt_string,
+ hf_loading_vars, hf_parsing_map):
+ pytest.importorskip('datasets')
+
+ tokenizer = tiny_gpt2_tokenizer
+ batch_size = 2
+ seqlen = 2048
+
+ # empirical number from the small test dataset
+ maximum_answer_length = 4
+
+ dl = get_icl_task_dataloader('question_answering',
+ dataset_uri=dataset_uri,
+ tokenizer=tokenizer,
+ batch_size=batch_size,
+ max_seq_len=seqlen,
+ pad_tok_id=tokenizer.eos_token_id,
+ num_fewshot=num_fewshot,
+ prompt_string=prompt_string,
+ example_delimiter='\n',
+ question_prelimiter='Orbs: ',
+ continuation_delimiter='\nSpell:',
+ destination_path=str(tmp_path / 'test_dataset_lm_juggernaut.jsonl'),
+ hf_loading_vars=hf_loading_vars,
+ hf_parsing_map=hf_parsing_map)
+ assert isinstance(dl, DataSpec)
+ assert isinstance(dl.dataloader, DataLoader) # pyright
+ batch = next(dl.dataloader._get_iterator())
+
+ assert tuple(batch['input_ids'].shape) == (batch_size, seqlen - maximum_answer_length)
+ assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen - maximum_answer_length)
+ assert batch['mode'] == 'generate'
+ # the maximum generation length from the small test data
+ assert batch['generation_length'] == maximum_answer_length
+ assert all(item[0] == tokenizer.eos_token_id for item in batch['input_ids'])
+
+ decoded_batch = tokenizer.batch_decode(batch['input_ids'])
+ assert all(item.count('Orbs: ') == num_fewshot + 1 for item in decoded_batch)
+ assert all(item.count('\nSpell:') == num_fewshot + 1 for item in decoded_batch)
+
+ if len(prompt_string) > 0:
+ assert all(item.count('What spell does this invoke? ') == 1 for item in decoded_batch)
+ assert all(
+ set(found) == set(expected) for found, expected in zip(batch['labels'], [['defeaning blast'], ['cold snap']]))
+ assert decoded_batch[0].endswith('Orbs: quas wex exort\nSpell:')
+ assert decoded_batch[1].endswith('Orbs: quas quas quas\nSpell:')
diff --git a/tests/datasets/test_mnist.py b/tests/datasets/test_mnist.py
deleted file mode 100644
index 7342184d03..0000000000
--- a/tests/datasets/test_mnist.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-import pytest
-
-from composer.datasets import build_mnist_dataloader, build_synthetic_mnist_dataloader
-
-
-@pytest.mark.parametrize('is_train', [False, True])
-@pytest.mark.parametrize('synthetic', [pytest.param(False, marks=pytest.mark.daily), True])
-def test_mnist_shape_length(is_train, synthetic):
- batch_size = 1
-
- if synthetic:
- loader = build_synthetic_mnist_dataloader(global_batch_size=batch_size, is_train=is_train)
- else:
- loader = build_mnist_dataloader(datadir='/tmp', global_batch_size=batch_size, is_train=is_train)
-
- samples = list(loader)
- if is_train:
- assert len(samples) == 60000 // batch_size
- else:
- assert len(samples) == 10000 // batch_size
-
- assert samples[0][0].shape == (1, 1, 28, 28)
diff --git a/tests/datasets/test_segmentation_transforms.py b/tests/datasets/test_segmentation_transforms.py
deleted file mode 100644
index 2e4af40126..0000000000
--- a/tests/datasets/test_segmentation_transforms.py
+++ /dev/null
@@ -1,84 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-import numpy as np
-import pytest
-from PIL import Image
-
-from composer.datasets.ade20k import (PadToSize, PhotometricDistoration, RandomCropPair, RandomHFlipPair,
- RandomResizePair)
-
-
-@pytest.fixture
-def size():
- return 16, 16
-
-
-@pytest.fixture
-def sample_pair(size):
- img = Image.new(mode='RGB', size=size)
- target = Image.new(mode='L', size=size)
- return img, target
-
-
-def test_random_resize(sample_pair, size):
- random_resize_transform = RandomResizePair(min_scale=0.5, max_scale=2.0, base_size=size)
-
- # Test that the resized image remains within bounds for 10 iterations
- for _ in range(10):
- resized_img, resized_target = random_resize_transform(sample_pair)
- assert resized_img.size == resized_target.size
- assert resized_img.size[0] >= size[0] // 2 and resized_img.size[0] <= size[0] * 2
- assert resized_img.size[1] >= size[1] // 2 and resized_img.size[1] <= size[1] * 2
-
-
-@pytest.mark.parametrize('crop_size', [(8, 8), (32, 32)])
-def test_random_crop(sample_pair, crop_size):
- random_crop_transform = RandomCropPair(crop_size)
- image, target = random_crop_transform(sample_pair)
- assert image.size == target.size
- final_size = min(crop_size[0], sample_pair[0].height), min(crop_size[1], sample_pair[0].width)
- assert final_size == image.size
-
-
-def test_random_hflip(sample_pair):
- old_image, old_target = np.array(sample_pair[0]), np.array(sample_pair[1])
-
- # Always flip
- always_hflip_transform = RandomHFlipPair(probability=1.0)
- new_image, new_target = always_hflip_transform(sample_pair)
- new_image, new_target = np.array(new_image), np.array(new_target)
- assert np.allclose(new_image, old_image[:, ::-1]) and np.allclose(new_target, old_target[:, ::-1])
-
- # Never flip
- always_hflip_transform = RandomHFlipPair(probability=0.0)
- new_image, new_target = always_hflip_transform(sample_pair)
- new_image, new_target = np.array(new_image), np.array(new_target)
- assert np.allclose(new_image, old_image) and np.allclose(new_target, old_target)
-
-
-@pytest.mark.parametrize('pad_size', [(32, 32), (8, 8)])
-def test_pad_transform(sample_pair, pad_size):
- image = sample_pair[0]
- pad_transform = PadToSize(size=pad_size, fill=255)
- padded_image = pad_transform(image)
- final_size = max(pad_size[1], image.width), max(pad_size[0], image.height)
- # Check for correct size and number of padding elements
- assert padded_image.size == final_size
-
- # Check appropriate amount of padding is used
- padded_image = np.array(padded_image)
- initial_area = image.width * image.height
- final_area = final_size[0] * final_size[1]
- n_channels = padded_image.shape[2]
- pad_volume = n_channels * (final_area - initial_area)
- assert pad_volume == (padded_image == 255).sum()
-
-
-def test_photometric_distortion(sample_pair):
- old_image = sample_pair[0]
- # Test no transform case
- photometric_transform = PhotometricDistoration(brightness=1.0, contrast=1.0, saturation=1.0, hue=0)
- new_image = photometric_transform(old_image)
- old_image, new_image = np.array(old_image), np.array(new_image)
- assert np.allclose(old_image, new_image)
diff --git a/tests/datasets/test_synthetic_data.py b/tests/datasets/test_synthetic_data.py
deleted file mode 100644
index 6f62aebb9d..0000000000
--- a/tests/datasets/test_synthetic_data.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-from typing import Optional
-
-import pytest
-import torch
-
-from composer.datasets.synthetic import (SyntheticBatchPairDataset, SyntheticDataLabelType, SyntheticDataType,
- SyntheticPILDataset)
-
-
-@pytest.mark.parametrize('data_type', [
- SyntheticDataType.GAUSSIAN,
- SyntheticDataType.SEPARABLE,
-])
-@pytest.mark.parametrize('label_type', [
- SyntheticDataLabelType.CLASSIFICATION_ONE_HOT,
- SyntheticDataLabelType.CLASSIFICATION_INT,
-])
-def test_synthetic_batch_pair_creation(data_type: SyntheticDataType, label_type: SyntheticDataLabelType):
- if data_type == SyntheticDataType.SEPARABLE:
- if label_type != SyntheticDataLabelType.CLASSIFICATION_INT:
- pytest.skip('Separable data requires classification int labels')
- num_classes = 2
- label_shape = None
- else:
- num_classes = 10
- label_shape = (1, 10, 12)
-
- if data_type == SyntheticDataType.GAUSSIAN and label_type == SyntheticDataLabelType.CLASSIFICATION_INT:
- pytest.xfail('classification_int is not currently supported with gaussian data')
-
- dataset_size = 1000
- data_shape = (3, 32, 32)
- num_samples_to_create = 10
- dataset = SyntheticBatchPairDataset(total_dataset_size=dataset_size,
- data_shape=data_shape,
- num_unique_samples_to_create=num_samples_to_create,
- data_type=data_type,
- label_type=label_type,
- num_classes=num_classes,
- label_shape=label_shape)
- assert len(dataset) == dataset_size
-
- # verify datapoints are correct
- x, y = dataset[0]
- assert x.size() == data_shape
- if label_type == SyntheticDataLabelType.CLASSIFICATION_INT:
- assert isinstance(y.item(), int)
- elif label_type == SyntheticDataLabelType.CLASSIFICATION_ONE_HOT:
- assert y.size() == (num_classes,)
- assert torch.min(y) == 0
- assert torch.max(y) == 1
-
- # check that points were allocated in memory after the first call to __getitem__
- assert dataset.input_data is not None
- assert dataset.input_target is not None
- # check that the correct number of points were allocated in memory
- assert dataset.input_data.size()[0] == num_samples_to_create
- assert dataset.input_target.size()[0] == num_samples_to_create
-
- # verify that you can getch points outside the num_samples_to_create range
- # (still within the total dataset size range)
- x, y = dataset[num_samples_to_create + 1]
- assert x is not None
- assert y is not None
-
-
-@pytest.mark.parametrize('label_type', [
- SyntheticDataLabelType.CLASSIFICATION_ONE_HOT,
- SyntheticDataLabelType.CLASSIFICATION_INT,
-])
-@pytest.mark.parametrize('num_classes', [None, 0])
-def test_synthetic_classification_param_validation(label_type: SyntheticDataLabelType, num_classes: Optional[int]):
- with pytest.raises(ValueError):
- SyntheticBatchPairDataset(total_dataset_size=10,
- data_shape=(2, 2),
- label_type=label_type,
- num_classes=num_classes)
-
-
-@pytest.mark.parametrize('data_type', [
- SyntheticDataType.GAUSSIAN,
- SyntheticDataType.SEPARABLE,
-])
-@pytest.mark.parametrize('label_type', [
- SyntheticDataLabelType.CLASSIFICATION_ONE_HOT,
- SyntheticDataLabelType.CLASSIFICATION_INT,
-])
-def test_synthetic_image_data_creation(data_type: SyntheticDataType, label_type: SyntheticDataLabelType):
- if data_type == SyntheticDataType.SEPARABLE:
- if label_type != SyntheticDataLabelType.CLASSIFICATION_INT:
- pytest.skip('Seperable data requires classification int labels')
- num_classes = 2
- label_shape = None
- else:
- num_classes = 10
- label_shape = (1, 10, 12)
-
- if data_type == SyntheticDataType.GAUSSIAN and label_type == SyntheticDataLabelType.CLASSIFICATION_INT:
- pytest.xfail('classification_int is not currently supported with gaussian data')
-
- dataset_size = 1000
- data_shape = (32, 32)
- num_samples_to_create = 100
- dataset = SyntheticPILDataset(total_dataset_size=dataset_size,
- data_shape=data_shape,
- num_unique_samples_to_create=num_samples_to_create,
- data_type=data_type,
- label_type=label_type,
- num_classes=num_classes,
- label_shape=label_shape)
- assert len(dataset) == dataset_size
-
- # verify datapoints are correct
- x, y = dataset[0]
- assert x.size == data_shape
- if label_type == SyntheticDataLabelType.CLASSIFICATION_INT:
- assert isinstance(y.item(), int)
- elif label_type == SyntheticDataLabelType.CLASSIFICATION_ONE_HOT:
- assert y.size() == (num_classes,)
- assert torch.min(y) == 0
- assert torch.max(y) == 1
-
- # check that points were allocated in memory after the first call to __getitem__
- assert dataset._dataset.input_data is not None
- assert dataset._dataset.input_target is not None
- # check that the correct number of points were allocated in memory
- assert dataset._dataset.input_data.shape[0] == num_samples_to_create
- assert dataset._dataset.input_target.shape[0] == num_samples_to_create
-
- # verify that you can getch points outside the num_samples_to_create range
- # (still within the total dataset size range)
- x, y = dataset[num_samples_to_create + 1]
- assert x is not None
- assert y is not None
diff --git a/tests/fixtures/fixtures.py b/tests/fixtures/fixtures.py
index cfd8674338..5ab9b472b0 100644
--- a/tests/fixtures/fixtures.py
+++ b/tests/fixtures/fixtures.py
@@ -244,11 +244,23 @@ def tiny_gpt2_tokenizer_helper():
return hf_tokenizer
+def tiny_llama_tokenizer_helper():
+ transformers = pytest.importorskip('transformers')
+
+ hf_tokenizer = transformers.AutoTokenizer.from_pretrained('huggyllama/llama-7b', use_fast=False)
+ return hf_tokenizer
+
+
@pytest.fixture(scope='session')
def _session_tiny_gpt2_tokenizer(): # type: ignore
return tiny_gpt2_tokenizer_helper()
+@pytest.fixture(scope='session')
+def _session_tiny_llama_tokenizer(): # type: ignore
+ return tiny_llama_tokenizer_helper()
+
+
def tiny_opt_model_helper(config):
transformers = pytest.importorskip('transformers')
@@ -320,6 +332,47 @@ def _session_tiny_t5_model(_session_tiny_t5_config): # type: ignore
return tiny_t5_model_helper(_session_tiny_t5_config)
+def tiny_mistral_config_helper():
+ transformers = pytest.importorskip('transformers')
+
+ tiny_overrides = {
+ 'hidden_size': 128,
+ 'intermediate_size': 256,
+ 'num_attention_heads': 8,
+ 'num_hidden_layers': 2,
+ 'num_kv_heads': 4
+ }
+ return transformers.AutoConfig.from_pretrained('mistralai/Mistral-7B-v0.1', **tiny_overrides)
+
+
+@pytest.fixture(scope='session')
+def _session_tiny_mistral_config(): # type: ignore
+ return tiny_mistral_config_helper()
+
+
+def tiny_mistral_tokenizer_helper():
+ transformers = pytest.importorskip('transformers')
+
+ hf_tokenizer = transformers.AutoTokenizer.from_pretrained('mistralai/Mistral-7B-v0.1', model_max_length=512)
+ return hf_tokenizer
+
+
+@pytest.fixture(scope='session')
+def _session_tiny_mistral_tokenizer(): # type: ignore
+ return tiny_mistral_tokenizer_helper()
+
+
+def tiny_mistral_model_helper(config):
+ transformers = pytest.importorskip('transformers')
+
+ return transformers.AutoModelForCausalLM.from_config(config)
+
+
+@pytest.fixture(scope='session')
+def _session_tiny_mistral_model(_session_tiny_mistral_config): # type: ignore
+ return tiny_mistral_model_helper(_session_tiny_mistral_config)
+
+
@pytest.fixture
def tiny_bert_model(_session_tiny_bert_model):
return copy.deepcopy(_session_tiny_bert_model)
@@ -360,6 +413,11 @@ def tiny_gpt2_tokenizer(_session_tiny_gpt2_tokenizer):
return copy.deepcopy(_session_tiny_gpt2_tokenizer)
+@pytest.fixture
+def tiny_llama_tokenizer(_session_tiny_llama_tokenizer):
+ return copy.deepcopy(_session_tiny_llama_tokenizer)
+
+
@pytest.fixture
def tiny_gpt2_model(_session_tiny_gpt2_model):
return copy.deepcopy(_session_tiny_gpt2_model)
@@ -393,3 +451,18 @@ def tiny_t5_tokenizer(_session_tiny_t5_tokenizer):
@pytest.fixture
def tiny_t5_model(_session_tiny_t5_model):
return copy.deepcopy(_session_tiny_t5_model)
+
+
+@pytest.fixture
+def tiny_mistral_config(_session_tiny_mistral_config):
+ return copy.deepcopy(_session_tiny_mistral_config)
+
+
+@pytest.fixture
+def tiny_mistral_tokenizer(_session_tiny_mistral_tokenizer):
+ return copy.deepcopy(_session_tiny_mistral_tokenizer)
+
+
+@pytest.fixture
+def tiny_mistral_model(_session_tiny_mistral_model):
+ return copy.deepcopy(_session_tiny_mistral_model)
diff --git a/tests/loggers/test_mlflow_logger.py b/tests/loggers/test_mlflow_logger.py
index 5ff0a2fa3c..d5de5b8171 100644
--- a/tests/loggers/test_mlflow_logger.py
+++ b/tests/loggers/test_mlflow_logger.py
@@ -29,7 +29,9 @@ def _get_latest_mlflow_run(experiment_name, tracking_uri=None):
# NB: Convert tracking URI to string because MlflowClient doesn't support non-string
# (e.g. PosixPath) tracking URI representations
client = MlflowClient(str(tracking_uri))
- experiment_id = (client.get_experiment_by_name(experiment_name).experiment_id)
+ experiment = client.get_experiment_by_name(experiment_name)
+ assert experiment is not None
+ experiment_id = experiment.experiment_id
first_run_or_empty = client.search_runs(
experiment_ids=[experiment_id],
max_results=1,
@@ -164,6 +166,26 @@ def test_mlflow_experiment_init_experiment_name(monkeypatch):
id_logger.post_close()
+def test_mlflow_experiment_init_existing_composer_run(monkeypatch):
+ """ Test that an existing MLFlow run is used if one already exists in the experiment for the Composer run.
+ """
+ mlflow = pytest.importorskip('mlflow')
+
+ monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock())
+ monkeypatch.setattr(mlflow, 'start_run', MagicMock())
+
+ mock_state = MagicMock()
+ mock_state.run_name = 'dummy-run-name'
+
+ existing_id = 'dummy-id'
+ mock_search_runs = MagicMock(return_value=[MagicMock(info=MagicMock(run_id=existing_id))])
+ monkeypatch.setattr(mlflow, 'search_runs', mock_search_runs)
+
+ test_logger = MLFlowLogger()
+ test_logger.init(state=mock_state, logger=MagicMock())
+ assert test_logger._run_id == existing_id
+
+
def test_mlflow_experiment_set_up(tmp_path):
""" Test that MLFlow experiment is set up correctly within mlflow
"""
@@ -189,6 +211,7 @@ def test_mlflow_experiment_set_up(tmp_path):
)
run_id = run.info.run_id
experiment_id = run.info.experiment_id
+ tags = run.data.tags
# Check uri set correctly.
assert mlflow_uri.exists()
@@ -207,6 +230,9 @@ def test_mlflow_experiment_set_up(tmp_path):
actual_run_name = run_cfg['run_name']
assert actual_run_name == expected_run_name
+ # Check run tagged with Composer run name.
+ assert tags['composer_run_name'] == mock_state.run_name
+
# Check run ended.
test_mlflow_logger.post_close()
assert mlflow.active_run() is None
@@ -336,6 +362,48 @@ def test_mlflow_save_model(tmp_path, tiny_gpt2_model, tiny_gpt2_tokenizer):
check_hf_tokenizer_equivalence(loaded_model['tokenizer'], tiny_gpt2_tokenizer)
+@pytest.mark.filterwarnings('ignore:.*Setuptools is replacing distutils.*:UserWarning')
+@pytest.mark.filterwarnings("ignore:.*The 'transformers' MLflow Models integration.*:FutureWarning")
+def test_mlflow_save_peft_model(tmp_path, tiny_mistral_model, tiny_mistral_tokenizer):
+ mlflow = pytest.importorskip('mlflow')
+ peft = pytest.importorskip('peft')
+
+ # Reload just so the model has the update base model name
+ tiny_mistral_model.save_pretrained(tmp_path / Path('tiny_mistral_save_pt'))
+ tiny_mistral_model = tiny_mistral_model.from_pretrained(tmp_path / Path('tiny_mistral_save_pt'))
+
+ peft_config = {'peft_type': 'LORA'}
+ peft_model = peft.get_peft_model(tiny_mistral_model, peft.get_peft_config(peft_config))
+
+ mlflow_uri = tmp_path / Path('my-test-mlflow-uri')
+ mlflow_exp_name = 'test-log-model-exp-name'
+ test_mlflow_logger = MLFlowLogger(
+ tracking_uri=mlflow_uri,
+ experiment_name=mlflow_exp_name,
+ )
+
+ mock_state = MagicMock()
+ mock_state.run_name = 'dummy-run-name' # this run name should be unused.
+ mock_logger = MagicMock()
+
+ peft_model.save_pretrained(tmp_path / Path('peft_model_save_pt'))
+ tiny_mistral_tokenizer.save_pretrained(tmp_path / Path('peft_model_save_pt'))
+
+ local_mlflow_save_path = str(tmp_path / Path('my_model_local'))
+ test_mlflow_logger.init(state=mock_state, logger=mock_logger)
+ test_mlflow_logger.save_model(
+ flavor='peft',
+ path=local_mlflow_save_path,
+ save_pretrained_dir=str(tmp_path / Path('peft_model_save_pt')),
+ )
+ test_mlflow_logger.post_close()
+
+ loaded_model = mlflow.pyfunc.load_model(local_mlflow_save_path).unwrap_python_model()
+
+ check_hf_model_equivalence(loaded_model.model, tiny_mistral_model)
+ check_hf_tokenizer_equivalence(loaded_model.tokenizer, tiny_mistral_tokenizer)
+
+
@pytest.mark.filterwarnings('ignore:.*Setuptools is replacing distutils.*:UserWarning')
@pytest.mark.filterwarnings("ignore:.*The 'transformers' MLflow Models integration.*:FutureWarning")
def test_mlflow_register_model(tmp_path, monkeypatch):
@@ -364,11 +432,54 @@ def test_mlflow_register_model(tmp_path, monkeypatch):
name='my_model',
)
- assert mlflow.register_model.called_with(model_uri=local_mlflow_save_path,
- name='my_catalog.my_schema.my_model',
- await_registration_for=300,
- tags=None,
- registry_uri='databricks-uc')
+ mlflow.register_model.assert_called_with(
+ model_uri=local_mlflow_save_path,
+ name='my_catalog.my_schema.my_model',
+ await_registration_for=300,
+ tags=None,
+ )
+ assert mlflow.get_registry_uri() == 'databricks-uc'
+
+ test_mlflow_logger.post_close()
+
+
+@pytest.mark.filterwarnings('ignore:.*Setuptools is replacing distutils.*:UserWarning')
+@pytest.mark.filterwarnings("ignore:.*The 'transformers' MLflow Models integration.*:FutureWarning")
+def test_mlflow_register_model_with_run_id(tmp_path, monkeypatch):
+ mlflow = pytest.importorskip('mlflow')
+
+ mlflow_uri = tmp_path / Path('my-test-mlflow-uri')
+ mlflow_exp_name = 'test-log-model-exp-name'
+ test_mlflow_logger = MLFlowLogger(
+ tracking_uri=mlflow_uri,
+ experiment_name=mlflow_exp_name,
+ model_registry_prefix='my_catalog.my_schema',
+ model_registry_uri='databricks-uc',
+ )
+
+ monkeypatch.setattr(test_mlflow_logger._mlflow_client, 'create_model_version', MagicMock())
+ monkeypatch.setattr(test_mlflow_logger._mlflow_client, 'create_registered_model',
+ MagicMock(return_value=type('MockResponse', (), {'name': 'my_catalog.my_schema.my_model'})))
+
+ mock_state = MagicMock()
+ mock_state.run_name = 'dummy-run-name' # this run name should be unused.
+ mock_logger = MagicMock()
+
+ local_mlflow_save_path = str(tmp_path / Path('my_model_local'))
+ test_mlflow_logger.init(state=mock_state, logger=mock_logger)
+
+ test_mlflow_logger.register_model_with_run_id(
+ model_uri=local_mlflow_save_path,
+ name='my_model',
+ )
+
+ test_mlflow_logger._mlflow_client.create_model_version.assert_called_with(
+ name='my_catalog.my_schema.my_model',
+ source=local_mlflow_save_path,
+ run_id=test_mlflow_logger._run_id,
+ await_creation_for=300,
+ tags=None,
+ )
assert mlflow.get_registry_uri() == 'databricks-uc'
test_mlflow_logger.post_close()
@@ -488,7 +599,8 @@ def test_mlflow_logging_works(tmp_path, device):
actual_params_list = [param_filepath.stem for param_filepath in param_path.iterdir()]
expected_params_list = [
- 'num_cpus_per_node', 'node_name', 'num_nodes', 'rank_zero_seed', 'composer_version', 'composer_commit_hash'
+ 'num_cpus_per_node', 'node_name', 'num_nodes', 'rank_zero_seed', 'composer_version', 'composer_commit_hash',
+ 'mlflow_experiment_id', 'mlflow_run_id'
]
assert set(expected_params_list) == set(actual_params_list)
@@ -549,3 +661,96 @@ def before_forward(self, state: State, logger: Logger):
run_file_path = mlflow_uri / Path(experiment_id) / Path(run_id)
im_dir = run_file_path / Path('artifacts')
assert len(os.listdir(im_dir)) == expected_num_ims
+
+
+@device('cpu')
+def test_mlflow_ignore_metrics(tmp_path, device):
+ mlflow_uri = tmp_path / Path('my-test-mlflow-uri')
+ experiment_name = 'mlflow_logging_test'
+ test_mlflow_logger = MLFlowLogger(
+ tracking_uri=mlflow_uri,
+ experiment_name=experiment_name,
+ log_system_metrics=False,
+ ignore_metrics=['metrics/eval/*', 'nothing/should/match', 'metrics/train/CrossEntropy'],
+ )
+
+ dataset_size = 64
+ batch_size = 4
+ num_batches = 4
+ eval_interval = '1ba'
+
+ trainer = Trainer(model=SimpleConvModel(),
+ loggers=test_mlflow_logger,
+ train_dataloader=DataLoader(RandomImageDataset(size=dataset_size), batch_size),
+ eval_dataloader=DataLoader(RandomImageDataset(size=dataset_size), batch_size),
+ max_duration=f'{num_batches}ba',
+ eval_interval=eval_interval,
+ device=device)
+ trainer.fit()
+ # Allow async logging to finish.
+ time.sleep(3)
+ test_mlflow_logger.post_close()
+
+ run = _get_latest_mlflow_run(
+ experiment_name=experiment_name,
+ tracking_uri=mlflow_uri,
+ )
+ run_id = run.info.run_id
+ experiment_id = run.info.experiment_id
+
+ run_file_path = mlflow_uri / Path(experiment_id) / Path(run_id)
+
+ # Test metrics logged.
+ for metric_name in [
+ 'metrics/train/MulticlassAccuracy',
+ 'loss/train/total',
+ ]:
+ metric_file = run_file_path / Path('metrics') / Path(metric_name)
+ with open(metric_file) as f:
+ csv_reader = csv.reader(f, delimiter=' ')
+ lines = list(csv_reader)
+
+ assert len(lines) == num_batches
+
+ # Test metrics are not logged.
+ for metric_name in ['metrics/eval/MulticlassAccuracy', 'metrics/eval/CrossEntropy', 'metrics/train/CrossEntropy']:
+ metric_file = run_file_path / Path('metrics') / Path(metric_name)
+ assert not os.path.exists(metric_file)
+
+ # Test system metrics are not logged.
+ metric_file = run_file_path / Path('metrics') / Path('system/cpu_utilization_percentage')
+ assert not os.path.exists(metric_file)
+
+
+def test_mlflow_ignore_hyperparameters(tmp_path):
+ mlflow_uri = tmp_path / Path('my-test-mlflow-uri')
+ experiment_name = 'mlflow_logging_test'
+ test_mlflow_logger = MLFlowLogger(tracking_uri=mlflow_uri,
+ experiment_name=experiment_name,
+ log_system_metrics=False,
+ ignore_hyperparameters=['num*', 'mlflow_run_id', 'nothing'])
+
+ Trainer(model=SimpleConvModel(), loggers=test_mlflow_logger, max_duration=f'4ba')
+ # Allow async logging to finish.
+ time.sleep(3)
+ test_mlflow_logger.post_close()
+
+ run = _get_latest_mlflow_run(
+ experiment_name=experiment_name,
+ tracking_uri=mlflow_uri,
+ )
+ run_file_path = mlflow_uri / Path(run.info.experiment_id) / Path(run.info.run_id)
+
+ # Test params logged.
+ param_path = run_file_path / Path('params')
+ actual_params_list = [param_filepath.stem for param_filepath in param_path.iterdir()]
+
+ # should not see num_cpus_per_node, num_nodes, mlflow_run_id
+ expected_params_list = [
+ 'node_name',
+ 'rank_zero_seed',
+ 'composer_version',
+ 'composer_commit_hash',
+ 'mlflow_experiment_id',
+ ]
+ assert set(expected_params_list) == set(actual_params_list)
diff --git a/tests/loggers/test_mosaicml_logger.py b/tests/loggers/test_mosaicml_logger.py
index 106acfc6fc..0834e3dbf0 100644
--- a/tests/loggers/test_mosaicml_logger.py
+++ b/tests/loggers/test_mosaicml_logger.py
@@ -59,10 +59,11 @@ def test_format_data_to_json_serializable():
'key3': 3.14,
'key4': True,
'key5': torch.tensor([1, 2, 3]),
- 'key6': {
+ 'key6': torch.tensor([42]),
+ 'key7': {
'inner_key': 'inner_value'
},
- 'key7': [1, 2, 3],
+ 'key8': [1, 2, 3],
}
formatted_data = format_data_to_json_serializable(data)
@@ -72,10 +73,11 @@ def test_format_data_to_json_serializable():
'key3': 3.14,
'key4': True,
'key5': 'Tensor of shape torch.Size([3])',
- 'key6': {
+ 'key6': 42,
+ 'key7': {
'inner_key': 'inner_value'
},
- 'key7': [1, 2, 3],
+ 'key8': [1, 2, 3],
}
assert formatted_data == expected_formatted_data
@@ -83,6 +85,7 @@ def test_format_data_to_json_serializable():
@pytest.mark.parametrize('callback_cls', get_cbs_and_marks(callbacks=True))
@world_size(1, 2)
+@pytest.mark.filterwarnings('ignore::UserWarning')
def test_logged_data_is_json_serializable(monkeypatch, callback_cls: Type[Callback], world_size):
"""Test that all logged data is json serializable, which is a requirement to use MAPI."""
diff --git a/tests/loggers/test_neptune_logger.py b/tests/loggers/test_neptune_logger.py
new file mode 100644
index 0000000000..4463595c0f
--- /dev/null
+++ b/tests/loggers/test_neptune_logger.py
@@ -0,0 +1,149 @@
+# Copyright 2022 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+import os
+import uuid
+from pathlib import Path
+from typing import Sequence
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+from torch.utils.data import DataLoader
+
+from composer import Trainer
+from composer._version import __version__
+from composer.loggers import NeptuneLogger
+from composer.utils import dist
+from tests.common import RandomImageDataset, SimpleConvModel
+from tests.common.markers import device
+
+
+@pytest.fixture
+def test_neptune_logger() -> NeptuneLogger:
+ neptune_project = 'test_project'
+ neptune_api_token = 'test_token'
+
+ neptune_logger = NeptuneLogger(
+ project=neptune_project,
+ api_token=neptune_api_token,
+ rank_zero_only=False,
+ mode='debug',
+ upload_artifacts=True,
+ )
+
+ return neptune_logger
+
+
+def test_neptune_init(test_neptune_logger):
+ mock_state = MagicMock()
+ mock_state.run_name = 'dummy-run-name' # should appear in sys/tags
+
+ test_neptune_logger.init(state=mock_state, logger=MagicMock())
+
+ assert test_neptune_logger.neptune_run is not None
+
+ test_neptune_logger.neptune_run.sync()
+ assert test_neptune_logger.neptune_run[NeptuneLogger.integration_version_key].fetch() == __version__
+ assert test_neptune_logger.neptune_run['sys/name'].fetch() == 'dummy-run-name'
+ assert test_neptune_logger.base_handler['rank'].fetch() == 0
+
+
+@device('cpu')
+def test_neptune_logging(device, test_neptune_logger):
+
+ dataset_size = 64
+ batch_size = 4
+ num_batches = 4
+ eval_interval = '1ba'
+
+ trainer = Trainer(model=SimpleConvModel(),
+ loggers=test_neptune_logger,
+ train_dataloader=DataLoader(RandomImageDataset(size=dataset_size), batch_size),
+ eval_dataloader=DataLoader(RandomImageDataset(size=dataset_size), batch_size),
+ max_duration=f'{num_batches}ba',
+ eval_interval=eval_interval,
+ device=device)
+ trainer.fit()
+
+ assert test_neptune_logger.neptune_run is not None
+ assert test_neptune_logger.base_handler is not None
+
+ for metric_name in [
+ 'metrics/train/MulticlassAccuracy', 'metrics/eval/MulticlassAccuracy', 'metrics/eval/CrossEntropy',
+ 'loss/train/total'
+ ]:
+ path = f'{test_neptune_logger._base_namespace}/{test_neptune_logger.metric_namespace}/{metric_name}'
+ assert test_neptune_logger.neptune_run.exists(path)
+
+ for hyperparam_name in ['node_name', 'num_cpus_per_node', 'num_nodes', 'rank_zero_seed']:
+ path = f'{test_neptune_logger._base_namespace}/{test_neptune_logger.hyperparam_namespace}/{hyperparam_name}'
+ assert test_neptune_logger.neptune_run.exists(path)
+
+ assert test_neptune_logger.base_handler['hyperparameters/num_nodes'].fetch() == 1
+
+
+@pytest.mark.gpu
+@pytest.mark.world_size(1, 2)
+def test_upload_and_download_file(test_neptune_logger, tmp_path, dummy_state):
+ neptune_artifact_name = 'test-neptune-artifact-' + str(uuid.uuid4())
+ tmp_paths = dist.all_gather_object(os.path.abspath(tmp_path))
+ save_folder = Path(tmp_paths[0])
+ file_content = 'hello from Neptune!'
+
+ dummy_neptune_artifact_path = save_folder / 'neptune_artifact.txt'
+ if dist.get_global_rank() == 0:
+ with open(dummy_neptune_artifact_path, 'w+') as f:
+ f.write(file_content)
+
+ test_neptune_logger.upload_file(state=dummy_state,
+ file_path=dummy_neptune_artifact_path,
+ remote_file_name=neptune_artifact_name)
+
+ dist.barrier()
+
+ assert test_neptune_logger.neptune_run.exists(f'{test_neptune_logger._base_namespace}/{neptune_artifact_name}')
+
+ dst_path = save_folder / 'neptune_artifact'
+
+ test_neptune_logger.download_file(
+ remote_file_name=neptune_artifact_name,
+ destination=str(dst_path),
+ )
+
+ assert dst_path.exists()
+
+ with open(str(dst_path), 'r') as fp:
+ assert fp.read() == file_content
+
+
+def test_neptune_log_image(test_neptune_logger):
+ pytest.importorskip('neptune', reason='neptune is optional')
+
+ with patch('neptune.attributes.FileSeries.extend', MagicMock()) as mock_extend:
+ image_variants = [
+ (torch.rand(4, 4), False), # 2D image
+ (torch.rand(2, 3, 4, 4), False), # multiple images, not channels last
+ (torch.rand(2, 3, 4, 4, dtype=torch.float64), False), # same as above but with float64
+ (torch.rand(3, 4, 4), False), # with channels, not channels last
+ ([torch.rand(4, 4, 3)], True), # with channels, channels last
+ (torch.rand(2, 4, 4, 3), True), # multiple images, channels last
+ ([torch.rand(4, 4, 3), torch.rand(4, 4, 3)], True) # multiple images in list
+ ]
+
+ expected_num_images_total = 0
+ for (images, channels_last) in image_variants:
+ if isinstance(images, Sequence):
+ expected_num_images = len(images)
+ np_images = [image.to(torch.float32).numpy() for image in images]
+
+ else:
+ expected_num_images = 1 if images.ndim < 4 else images.shape[0]
+ np_images = images.to(torch.float32).numpy()
+ test_neptune_logger.log_images(images=images, channels_last=channels_last)
+ test_neptune_logger.log_images(images=np_images, channels_last=channels_last)
+
+ expected_num_images *= 2 # One set of torch tensors, one set of numpy arrays
+ expected_num_images_total += expected_num_images
+
+ test_neptune_logger.post_close()
+ assert mock_extend.call_count == 2 * len(image_variants) # One set of torch tensors, one set of numpy arrays
diff --git a/tests/loggers/test_wandb_logger.py b/tests/loggers/test_wandb_logger.py
index 1ccfc5e53a..c9cfe0fc6c 100644
--- a/tests/loggers/test_wandb_logger.py
+++ b/tests/loggers/test_wandb_logger.py
@@ -247,6 +247,7 @@ def test_wandb_log_metrics(test_wandb_logger):
@pytest.mark.parametrize('callback_cls', get_cbs_and_marks(callbacks=True))
+@pytest.mark.filterwarnings('ignore::UserWarning')
def test_logged_data_is_json_serializable(callback_cls: Type[Callback]):
"""Test that all logged data is json serializable, which is a requirement to use wandb."""
pytest.importorskip('wandb', reason='wandb is optional')
diff --git a/tests/metrics/metric_setter_callback.py b/tests/metrics/metric_setter_callback.py
index 6b90c26bfe..63ec8db305 100644
--- a/tests/metrics/metric_setter_callback.py
+++ b/tests/metrics/metric_setter_callback.py
@@ -60,6 +60,7 @@ def _update_metrics(self, state: State):
# assert for pyright error: "module_to_device" is not a known member of "None"
assert self.device is not None
self.device.module_to_device(raw_metric)
+ assert state.train_metrics is not None
if self.dataloader_label == 'train':
state.train_metrics[self.monitor] = raw_metric
else:
diff --git a/tests/metrics/test_current_metrics.py b/tests/metrics/test_current_metrics.py
index d5315e3993..0f75349f9c 100644
--- a/tests/metrics/test_current_metrics.py
+++ b/tests/metrics/test_current_metrics.py
@@ -29,12 +29,14 @@ def batch_end(self, state: State, logger: Logger) -> None:
# The metric should be computed and updated on state every batch.
del logger # unused
# assuming that at least one sample was correctly classified
+ assert state.train_metrics is not None
assert state.train_metrics['MulticlassAccuracy'].compute() != 0.0
self._train_batch_end_train_accuracy = state.train_metrics['MulticlassAccuracy']
def epoch_end(self, state: State, logger: Logger) -> None:
# The metric at epoch end should be the same as on batch end.
del logger # unused
+ assert state.train_metrics is not None
assert state.train_metrics['MulticlassAccuracy'].compute() == self._train_batch_end_train_accuracy
def eval_end(self, state: State, logger: Logger) -> None:
@@ -85,6 +87,7 @@ def test_current_metrics(eval_interval: str,):
return
# Validate the metrics
+ assert trainer.state.train_metrics is not None
assert trainer.state.train_metrics['MulticlassAccuracy'].compute() != 0.0
if compute_val_metrics:
diff --git a/tests/metrics/test_nlp_metrics.py b/tests/metrics/test_nlp_metrics.py
index a37e53ca8f..e31cd4d410 100644
--- a/tests/metrics/test_nlp_metrics.py
+++ b/tests/metrics/test_nlp_metrics.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import math
+from typing import Optional
import pytest
import torch
@@ -10,8 +11,9 @@
from composer.metrics.nlp import (BinaryF1Score, InContextLearningCodeEvalAccuracy,
InContextLearningExpectedCalibrationError, InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
- InContextLearningMCExpectedCalibrationError, InContextLearningMultipleChoiceAccuracy,
- InContextLearningQAAccuracy, LanguageCrossEntropy, LanguagePerplexity, MaskedAccuracy)
+ InContextLearningMCExpectedCalibrationError, InContextLearningMetric,
+ InContextLearningMultipleChoiceAccuracy, InContextLearningQAAccuracy,
+ LanguageCrossEntropy, LanguagePerplexity, MaskedAccuracy)
@pytest.mark.parametrize('ignore_index', [-100])
@@ -53,7 +55,7 @@ def test_masked_accuracy(ignore_index, num_classes):
@pytest.mark.parametrize('sequence_length', [128])
@pytest.mark.parametrize('num_classes', [2, 10])
@pytest.mark.parametrize('minibatch_size', [56, 256, 768])
-def test_cross_entropy(batch_size: float, ignore_index: int, sequence_length: int, num_classes: int,
+def test_cross_entropy(batch_size: float, ignore_index: Optional[int], sequence_length: int, num_classes: int,
minibatch_size: int):
"""Sanity check to make sure that batched CrossEntropyLoss matches the expected performance.
@@ -71,15 +73,15 @@ def test_cross_entropy(batch_size: float, ignore_index: int, sequence_length: in
generated_preds = torch.randn((batch_size, sequence_length, num_classes))
generated_true = torch.randint(low=0, high=num_classes, size=(batch_size, sequence_length))
+ assert ignore_index is not None
torchmetrics_xent = LanguageCrossEntropy(dist_sync_on_step=False, ignore_index=ignore_index)
ce_with_keys_metric = LanguageCrossEntropy(dist_sync_on_step=False, ignore_index=ignore_index)
- if ignore_index is not None:
- labels_mask = torch.rand((batch_size, sequence_length))
- labels_mask[labels_mask > 0.8] = 1
- labels_mask[labels_mask <= 0.8] = 0
- labels_mask = labels_mask.bool()
- generated_true[labels_mask] = ignore_index
+ labels_mask = torch.rand((batch_size, sequence_length))
+ labels_mask[labels_mask > 0.8] = 1
+ labels_mask[labels_mask <= 0.8] = 0
+ labels_mask = labels_mask.bool()
+ generated_true[labels_mask] = ignore_index
num_batches = math.ceil(batch_size / minibatch_size)
for batch_idx in range(num_batches):
@@ -171,6 +173,53 @@ def test_language_perplexity():
assert torch.equal(torch.exp(ce), perplexity)
+def test_in_context_learning_rename_args_no_op():
+ batch = {'input': [1, 2, 3]}
+ outputs = torch.Tensor([12, 13, 14])
+ labels = torch.Tensor([0, 1, 0])
+ batch, outputs, labels = InContextLearningMetric.rename_args(batch=batch, outputs=outputs, labels=labels)
+ assert batch == {'input': [1, 2, 3]}
+ assert torch.all(torch.eq(outputs, torch.tensor([12, 13, 14])))
+ assert torch.all(torch.eq(labels, torch.tensor([0, 1, 0])))
+
+
+def test_in_context_learning_rename_args_output_and_output_logits():
+ batch = {'input': [1, 2, 3]}
+ outputs = torch.Tensor([12, 13, 14])
+ output_logits = torch.Tensor([.1, .2, .3])
+ labels = torch.Tensor([0, 1, 0])
+ with pytest.raises(ValueError):
+ _, _, _ = InContextLearningMetric.rename_args(batch=batch,
+ outputs=outputs,
+ labels=labels,
+ output_logits=output_logits)
+
+
+def test_in_context_learning_rename_args_rename_output_logits():
+ batch = {'input': [1, 2, 3]}
+ output_logits = torch.Tensor([.1, .2, .3])
+ labels = torch.Tensor([0, 1, 0])
+ batch, outputs, labels = InContextLearningMetric.rename_args(batch=batch,
+ labels=labels,
+ output_logits=output_logits)
+ assert batch == {'input': [1, 2, 3]}
+ assert torch.all(torch.eq(outputs, torch.Tensor([.1, .2, .3]))) # pyright: ignore [reportGeneralTypeIssues]
+ assert torch.all(torch.eq(labels, torch.tensor([0, 1, 0])))
+
+
+def test_in_context_learning_rename_args_fail_on_no_label():
+ batch = {'input': [1, 2, 3]}
+ output_logits = torch.Tensor([.1, .2, .3])
+ with pytest.raises(ValueError):
+ _, _, _ = InContextLearningMetric.rename_args(batch=batch, output_logits=output_logits)
+
+
+def test_in_context_learning_rename_args_fail_on_no_output():
+ batch = {'input': [1, 2, 3]}
+ with pytest.raises(ValueError):
+ _, _, _ = InContextLearningMetric.rename_args(batch=batch)
+
+
def test_in_context_learning_lm_accuracy(tiny_gpt2_tokenizer):
contexts = ['The dog is', 'I love to eat', 'I hate', 'The weather is']
continuations = [' furry', ' pie', ' long lines', ' snowy']
@@ -237,12 +286,12 @@ def test_in_context_learning_qa_accuracy():
def test_in_context_learning_qa_cot_accuracy():
outputs = [
- 'chain of thought ### Correct but then some more text', 'Incorrect',
- 'chain of thought ### the CORREct with weird casing and spacing',
+ 'chain of thought ### Correct but then some more text\n\nanother chain of thought ### Incorrect answer this time',
+ 'Incorrect', 'chain of thought ### the CORREct with weird casing and spacing',
'incorrect chain of thought delimiter ## Correct but wrong delimiter'
]
labels = [['Correct'], ['blah', 'blah2'], ['blah', 'correct'], ['correct']]
- batch = {'cot_delimiter': ' ### ', 'labels': labels}
+ batch = {'cot_delimiter': ' ### ', 'labels': labels, 'do_normalization': True, 'stopping_criteria': '\n\n'}
metric = InContextLearningQAAccuracy()
metric.update(outputs, labels, batch)
diff --git a/tests/models/test_bert.py b/tests/models/test_bert.py
deleted file mode 100644
index 82caa80f45..0000000000
--- a/tests/models/test_bert.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-import pytest
-from torch.utils.data import DataLoader
-
-from composer.models.bert import create_bert_classification, create_bert_mlm
-from composer.trainer import Trainer
-from tests.common.datasets import RandomTextClassificationDataset, RandomTextLMDataset
-
-
-def test_bert_mlm_hf_factory(tiny_bert_config, tiny_bert_tokenizer, monkeypatch):
- transformers = pytest.importorskip('transformers')
- monkeypatch.setattr('transformers.AutoConfig.from_pretrained', lambda x: tiny_bert_config)
- bert_composer_model = create_bert_mlm(use_pretrained=False,
- pretrained_model_name='dummy',
- model_config=None,
- tokenizer_name=None,
- gradient_checkpointing=False)
-
- train_dataset = RandomTextLMDataset(size=8,
- vocab_size=tiny_bert_tokenizer.vocab_size,
- sequence_length=8,
- use_keys=True)
- collator = transformers.DataCollatorForLanguageModeling(tokenizer=tiny_bert_tokenizer,
- mlm=True,
- mlm_probability=0.15)
- train_dataloader = DataLoader(train_dataset, batch_size=4, collate_fn=collator)
-
- trainer = Trainer(model=bert_composer_model, train_dataloader=train_dataloader, max_duration='1ep')
- trainer.fit()
-
- assert trainer.state.train_metrics['LanguageCrossEntropy'].compute() > 0.0
-
-
-def test_bert_classification_hf_factory(tiny_bert_config, tiny_bert_tokenizer, monkeypatch):
- pytest.importorskip('transformers')
-
- def config_patch(x, num_labels):
- tiny_bert_config.num_labels = num_labels
- return tiny_bert_config
-
- monkeypatch.setattr('transformers.AutoConfig.from_pretrained', config_patch)
- bert_composer_model = create_bert_classification(use_pretrained=False,
- pretrained_model_name='dummy',
- model_config=None,
- tokenizer_name=None,
- gradient_checkpointing=False,
- num_labels=3)
-
- train_dataset = RandomTextClassificationDataset(size=8,
- vocab_size=tiny_bert_tokenizer.vocab_size,
- sequence_length=8,
- num_classes=3,
- use_keys=True)
- train_dataloader = DataLoader(train_dataset, batch_size=4)
-
- trainer = Trainer(model=bert_composer_model, train_dataloader=train_dataloader, max_duration='1ep')
- trainer.fit()
-
- assert trainer.state.train_metrics['MulticlassAccuracy'].compute() > 0.0
diff --git a/tests/models/test_efficientnet.py b/tests/models/test_efficientnet.py
deleted file mode 100644
index a11dccc87b..0000000000
--- a/tests/models/test_efficientnet.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-import pytest
-import torch
-
-from composer.models.efficientnetb0.efficientnets import EfficientNet
-
-
-@pytest.mark.gpu
-def test_efficientb0_activate_shape():
- # Running this test on cuda as convolutions are slow on CPU
- random_input = torch.rand(2, 3, 224, 224).cuda()
-
- model = EfficientNet.get_model_from_name(
- 'efficientnet-b0',
- num_classes=1000,
- drop_connect_rate=0.2,
- ).cuda()
- # Test Stem
- out = model.conv_stem(random_input)
- out = model.bn1(out)
- out = model.act1(out)
- assert out.shape == (2, 32, 112, 112)
-
- # Test each block, shapes found at Table 1 of EfficientNet paper
- block_act_shape = [
- (2, 16, 112, 112),
- (2, 24, 56, 56),
- (2, 24, 56, 56),
- (2, 40, 28, 28),
- (2, 40, 28, 28),
- (2, 80, 14, 14),
- (2, 80, 14, 14),
- (2, 80, 14, 14),
- (2, 112, 14, 14),
- (2, 112, 14, 14),
- (2, 112, 14, 14),
- (2, 192, 7, 7),
- (2, 192, 7, 7),
- (2, 192, 7, 7),
- (2, 192, 7, 7),
- (2, 320, 7, 7),
- ]
- for i, block in enumerate(model.blocks):
- out = block(out)
- assert out.shape == block_act_shape[i]
-
- out = model.conv_head(out)
- assert out.shape == (2, 1280, 7, 7)
diff --git a/tests/models/test_gpt2.py b/tests/models/test_gpt2.py
deleted file mode 100644
index 7bbb878e5e..0000000000
--- a/tests/models/test_gpt2.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-import pytest
-from torch.utils.data import DataLoader
-
-from composer.models.gpt2 import create_gpt2
-from composer.trainer import Trainer
-from tests.common.datasets import RandomTextLMDataset
-
-
-def test_gpt2_hf_factory(tiny_gpt2_config, tiny_gpt2_tokenizer, monkeypatch):
- transformers = pytest.importorskip('transformers')
- monkeypatch.setattr('transformers.AutoConfig.from_pretrained', lambda x: tiny_gpt2_config)
- gpt2_composer_model = create_gpt2(use_pretrained=False,
- pretrained_model_name='dummy',
- model_config=None,
- tokenizer_name=None,
- gradient_checkpointing=False)
-
- train_dataset = RandomTextLMDataset(size=8,
- vocab_size=tiny_gpt2_tokenizer.vocab_size,
- sequence_length=8,
- use_keys=True)
- collator = transformers.DataCollatorForLanguageModeling(tokenizer=tiny_gpt2_tokenizer, mlm=False)
- train_dataloader = DataLoader(train_dataset, batch_size=4, collate_fn=collator)
-
- trainer = Trainer(model=gpt2_composer_model, train_dataloader=train_dataloader, max_duration='1ep')
- trainer.fit()
-
- assert trainer.state.train_metrics['LanguagePerplexity'].compute() > 0.0
diff --git a/tests/models/test_hf_model.py b/tests/models/test_hf_model.py
index 0f6076116f..e677941e9e 100644
--- a/tests/models/test_hf_model.py
+++ b/tests/models/test_hf_model.py
@@ -6,7 +6,7 @@
import tempfile
from contextlib import nullcontext
from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
from unittest.mock import patch
from urllib.parse import urlparse
@@ -26,9 +26,48 @@
from tests.common.datasets import RandomTextClassificationDataset, RandomTextLMDataset, RandomTextRegressionDataset
from tests.common.markers import device, world_size
from tests.common.models import (configure_tiny_bert_model, configure_tiny_bert_tokenizer, configure_tiny_gpt2_model,
- configure_tiny_gpt2_tokenizer, configure_tiny_t5_model, configure_tiny_t5_tokenizer)
+ configure_tiny_gpt2_tokenizer, configure_tiny_mistral_model,
+ configure_tiny_mistral_tokenizer, configure_tiny_t5_model, configure_tiny_t5_tokenizer)
from tests.loggers.test_remote_uploader_downloader import DummyObjectStore
+if TYPE_CHECKING:
+ from peft import PeftConfig
+
+
+def _gpt2_peft_config():
+ pytest.importorskip('peft')
+ from peft import get_peft_config
+
+ peft_config = get_peft_config({
+ 'peft_type': 'LORA',
+ 'task_type': 'CAUSAL_LM',
+ 'target_modules': ['c_attn'],
+ 'fan_in_fan_out': True,
+ })
+ return peft_config
+
+
+@pytest.fixture
+def gpt2_peft_config():
+ return _gpt2_peft_config()
+
+
+def _mistral_peft_config():
+ pytest.importorskip('peft')
+ from peft import get_peft_config
+
+ peft_config = get_peft_config({
+ 'peft_type': 'LORA',
+ 'task_type': 'CAUSAL_LM',
+ 'target_modules': ['up_proj'],
+ })
+ return peft_config
+
+
+@pytest.fixture
+def mistral_peft_config():
+ return _mistral_peft_config()
+
def test_hf_tokenizer_save(tmp_path: Path, tiny_bert_model, tiny_bert_tokenizer):
transformers = pytest.importorskip('transformers')
@@ -98,6 +137,7 @@ def test_hf_train_eval_predict(num_classes: int, tiny_bert_config):
trainer.eval()
# Check that there is some train/eval accuracy
+ assert trainer.state.train_metrics is not None
assert trainer.state.train_metrics['MulticlassAccuracy'].compute() != 0.0
assert trainer.state.eval_metrics['eval']['MulticlassAccuracy'].compute() != 0.0
@@ -153,6 +193,7 @@ def test_hf_train_eval_predict_regression(tiny_deberta_config):
trainer.eval()
# Check that there is some train/eval accuracy
+ assert trainer.state.train_metrics is not None
assert trainer.state.train_metrics['PearsonCorrCoef'].compute() != 0.0
assert trainer.state.eval_metrics['eval']['PearsonCorrCoef'].compute() != 0.0
@@ -431,14 +472,33 @@ def get_lm_trainer(hf_model,
device_train_microbatch_size: Optional[int] = None,
batch_size: int = 4,
sequence_length: int = 4,
- size: int = 4):
+ size: int = 4,
+ peft_config: Optional['PeftConfig'] = None,
+ should_save_peft_only: bool = False):
transformers = pytest.importorskip('transformers')
metrics: List[Metric] = [LanguageCrossEntropy(ignore_index=-100)]
if not is_conditional_generation:
metrics.append(MaskedAccuracy(ignore_index=-100))
- model = HuggingFaceModel(hf_model, tokenizer=hf_tokenizer, metrics=metrics, use_logits=True)
+ model = HuggingFaceModel(
+ hf_model,
+ tokenizer=hf_tokenizer,
+ metrics=metrics,
+ use_logits=True,
+ peft_config=peft_config,
+ should_save_peft_only=should_save_peft_only,
+ )
+
+ # On torch 2.0, fsdp wrapped modules can not have both frozen and unfrozen params.
+ # On 2.1+, if you have use_orig_params=True, they can. So we need a special case for the tests here.
+ if version.parse(torch.__version__) < version.parse('2.1.0') and peft_config is not None:
+ for name, module in model.named_modules():
+ if 'lora' in name.lower() and 'default' in name.lower():
+ has_parameters = any(True for _ in module.parameters())
+ has_buffers = any(True for _ in module.buffers())
+ if has_parameters or has_buffers:
+ module._fsdp_wrap = True # type: ignore
vocab_size = hf_model.config.vocab_size
sequence_length = 4
@@ -475,8 +535,13 @@ def get_lm_trainer(hf_model,
collate_fn=collator,
sampler=dist.get_sampler(train_dataset))
+ from composer.optim import DecoupledAdamW
+
+ optimizer = DecoupledAdamW(model.parameters(), lr=1e-3)
+
in_memory_logger = InMemoryLogger()
trainer = Trainer(model=model,
+ optimizers=optimizer,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
max_duration='1ep',
@@ -865,8 +930,6 @@ def test_encoder_decoder(tiny_t5_model, tiny_t5_tokenizer):
@pytest.mark.gpu
-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
- reason='requires PyTorch 1.13 or higher')
@pytest.mark.filterwarnings('ignore::UserWarning')
def test_hf_fsdp(tiny_bert_config, tiny_bert_tokenizer):
transformers = pytest.importorskip('transformers')
@@ -909,6 +972,7 @@ def test_separate_eval_metrics(tiny_bert_model, tiny_bert_tokenizer):
@pytest.mark.parametrize('checkpoint_upload_folder', [None, 's3://checkpoints-bucket/'])
@pytest.mark.parametrize('local_save_filename', [None, 'local-checkpoint.pt'])
+@pytest.mark.filterwarnings('ignore:TypedStorage is deprecated.*:UserWarning')
def test_write_hf_from_composer(checkpoint_upload_folder, local_save_filename, tiny_bert_model, tiny_bert_tokenizer,
tmp_path):
transformers = pytest.importorskip('transformers')
@@ -943,6 +1007,7 @@ def test_write_hf_from_composer(checkpoint_upload_folder, local_save_filename, t
check_hf_model_equivalence(tiny_bert_model, loaded_hf_model)
+@pytest.mark.filterwarnings('ignore:TypedStorage is deprecated.*:UserWarning')
def test_write_hf_from_composer_direct(tiny_bert_tokenizer, tmp_path):
# tests that the logic to write out a huggingface checkpoint from a composer checkpoint
# still works when the huggingface model is instantiated directly rather than using from_pretrained
@@ -1028,9 +1093,6 @@ def test_embedding_resizing(tiny_bert_model, tiny_bert_tokenizer, embedding_resi
@pytest.mark.parametrize('hf_model,hf_tokenizer', [(configure_tiny_gpt2_model, configure_tiny_gpt2_tokenizer),
(configure_tiny_t5_model, configure_tiny_t5_tokenizer)])
def test_generate(device, world_size, hf_model, hf_tokenizer, use_fsdp):
- if use_fsdp and version.parse(torch.__version__) < version.parse('1.13.0'):
- pytest.skip('FSDP requires torch >= 1.13.0')
-
transformers = pytest.importorskip('transformers')
if device == 'cpu' and use_fsdp:
pytest.skip('FSDP is not supported on CPU.')
@@ -1074,12 +1136,10 @@ def test_generate(device, world_size, hf_model, hf_tokenizer, use_fsdp):
generation1 = model.generate(**input_dict, max_new_tokens=5, pad_token_id=hf_tokenizer.pad_token_id)
generation2 = model.generate(**input_dict, max_new_tokens=3, pad_token_id=hf_tokenizer.pad_token_id)
- assert generation1.shape == (2,
- (input_dict['input_ids'].shape[1] if not hf_model.config.is_encoder_decoder else 1) +
- 5)
- assert generation2.shape == (2,
- (input_dict['input_ids'].shape[1] if not hf_model.config.is_encoder_decoder else 1) +
- 3)
+ generation1_dim2 = (input_dict['input_ids'].shape[1] if not hf_model.config.is_encoder_decoder else 1) + 5
+ assert generation1.shape == (2, generation1_dim2) # pyright: ignore[reportGeneralTypeIssues]
+ generation2_dim2 = (input_dict['input_ids'].shape[1] if not hf_model.config.is_encoder_decoder else 1) + 3
+ assert generation2.shape == (2, generation2_dim2) # pyright: ignore[reportGeneralTypeIssues]
decoded_generation1 = hf_tokenizer.batch_decode(generation1, skip_special_tokens=True)
decoded_generation2 = hf_tokenizer.batch_decode(generation2, skip_special_tokens=True)
@@ -1095,8 +1155,6 @@ def test_generate(device, world_size, hf_model, hf_tokenizer, use_fsdp):
@pytest.mark.parametrize('hf_model,hf_tokenizer', [(configure_tiny_gpt2_model, configure_tiny_gpt2_tokenizer),
(configure_tiny_t5_model, configure_tiny_t5_tokenizer)])
def test_eval_forward_generate(device, world_size, hf_model, hf_tokenizer, use_fsdp):
- if use_fsdp and version.parse(torch.__version__) < version.parse('1.13.0'):
- pytest.skip('FSDP requires torch >= 1.13.0')
transformers = pytest.importorskip('transformers')
if device == 'cpu' and use_fsdp:
pytest.skip('FSDP is not supported on CPU.')
@@ -1148,3 +1206,229 @@ def test_eval_forward_generate(device, world_size, hf_model, hf_tokenizer, use_f
assert len(generation1) == len(generation2) == 2
assert all(isinstance(decoded_generation, str) for decoded_generation in generation1)
assert all(isinstance(decoded_generation, str) for decoded_generation in generation2)
+
+
+@pytest.mark.parametrize('peft_type', ['LORA', 'loRa'])
+@pytest.mark.parametrize('task_type', ['CAUSAL_LM', 'causal_lm'])
+def test_peft_init(peft_type: str, task_type: str, tiny_gpt2_model, gpt2_peft_config):
+ pytest.importorskip('peft')
+ from peft import PeftModelForCausalLM
+
+ peft_config = copy.deepcopy(gpt2_peft_config)
+ peft_config.peft_type = peft_type
+ peft_config.task_type = task_type
+
+ original_model = copy.deepcopy(tiny_gpt2_model)
+
+ hf_model = HuggingFaceModel(tiny_gpt2_model, peft_config=peft_config)
+ assert isinstance(hf_model.model, PeftModelForCausalLM)
+ assert hf_model.model.peft_config['default'].peft_type == 'LORA'
+ assert hf_model.model.peft_config['default'].task_type == 'CAUSAL_LM'
+ assert hf_model.model.config == original_model.config
+
+
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.0'), reason='requires PyTorch 2+')
+def test_peft_init_errors(tiny_gpt2_model, gpt2_peft_config):
+ pytest.importorskip('peft')
+ peft_config = copy.deepcopy(gpt2_peft_config)
+ peft_config.peft_type = 'NOT_LORA'
+
+ with pytest.raises(ValueError):
+ _ = HuggingFaceModel(tiny_gpt2_model, peft_config=peft_config)
+
+
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.0'), reason='requires PyTorch 2+')
+def test_peft_init_not_installed(tiny_gpt2_model, gpt2_peft_config):
+ pytest.importorskip('peft')
+
+ with patch('composer.models.huggingface.peft_installed', False):
+ with pytest.raises(ImportError):
+ from composer.models import HuggingFaceModel
+ _ = HuggingFaceModel(tiny_gpt2_model, peft_config=gpt2_peft_config)
+
+
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.0'), reason='requires PyTorch 2+')
+@pytest.mark.parametrize('should_save_peft_only', [True, False])
+def test_peft_trains_and_loads(tiny_gpt2_model, tiny_gpt2_tokenizer, gpt2_peft_config, tmp_path, should_save_peft_only):
+ pytest.importorskip('peft')
+
+ trainer = get_lm_trainer(
+ tiny_gpt2_model,
+ tiny_gpt2_tokenizer,
+ str(tmp_path),
+ peft_config=gpt2_peft_config,
+ device_train_microbatch_size=1,
+ mlm=False,
+ should_save_peft_only=should_save_peft_only,
+ )
+ trainer.fit()
+
+ load_trainer = get_lm_trainer(
+ tiny_gpt2_model,
+ tiny_gpt2_tokenizer,
+ str(tmp_path),
+ peft_config=gpt2_peft_config,
+ device_train_microbatch_size=1,
+ mlm=False,
+ load_path=str(tmp_path / 'hf-checkpoint.pt'),
+ should_save_peft_only=should_save_peft_only,
+ )
+
+ for p1, p2 in zip(trainer.state.model.parameters(), load_trainer.state.model.parameters()):
+ torch.testing.assert_close(p1, p2)
+
+
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.0'), reason='requires PyTorch 2+')
+@pytest.mark.parametrize('model,tokenizer,peft_config', [
+ (configure_tiny_gpt2_model, configure_tiny_gpt2_tokenizer, _gpt2_peft_config()),
+ (configure_tiny_mistral_model, configure_tiny_mistral_tokenizer, _mistral_peft_config()),
+])
+def test_peft_generate(model, tokenizer, peft_config):
+ pytest.importorskip('peft')
+
+ model = model()
+ tokenizer = tokenizer()
+
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ hf_model = HuggingFaceModel(model, tokenizer=tokenizer, peft_config=peft_config)
+
+ input_dict = tokenizer(['hello', 'goodbyes'], return_tensors='pt', padding=True)
+ hf_model.generate(**input_dict, max_new_tokens=5, pad_token_id=tokenizer.pad_token_id)
+
+
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.0'), reason='requires PyTorch 2+')
+def test_peft_metadata(tiny_gpt2_model, tiny_gpt2_tokenizer, gpt2_peft_config):
+ pytest.importorskip('peft')
+
+ from peft import get_peft_config
+
+ hf_model = HuggingFaceModel(tiny_gpt2_model, tokenizer=tiny_gpt2_tokenizer, peft_config=gpt2_peft_config)
+ metadata = hf_model.get_metadata()
+ loaded_peft_config = get_peft_config(metadata['model']['peft_config']['content'])
+
+ assert loaded_peft_config == gpt2_peft_config
+
+
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.0'), reason='requires PyTorch 2+')
+@pytest.mark.parametrize('should_save_peft_only', [True, False])
+def test_peft_write_hf_from_composer(tiny_gpt2_model, tiny_gpt2_tokenizer, gpt2_peft_config, tmp_path,
+ should_save_peft_only):
+ peft = pytest.importorskip('peft')
+ transformers = pytest.importorskip('transformers')
+
+ # Simulate a local model instead of a hub model
+ tiny_gpt2_model.save_pretrained(tmp_path / 'hf-save-to-load')
+ tiny_gpt2_model = transformers.AutoModelForCausalLM.from_pretrained(tmp_path / 'hf-save-to-load')
+
+ trainer = get_lm_trainer(
+ tiny_gpt2_model,
+ tiny_gpt2_tokenizer,
+ str(tmp_path),
+ peft_config=gpt2_peft_config,
+ device_train_microbatch_size=1,
+ mlm=False,
+ should_save_peft_only=should_save_peft_only,
+ )
+ trainer.fit()
+
+ from composer.models.huggingface import write_huggingface_pretrained_from_composer_checkpoint
+ write_huggingface_pretrained_from_composer_checkpoint(str(tmp_path / 'hf-checkpoint.pt'),
+ tmp_path / 'hf-save-pretrained')
+
+ # Test we can load back in using transformers interface
+ loaded_hf_model = transformers.AutoModelForCausalLM.from_pretrained(str(tmp_path / 'hf-save-pretrained'))
+ for p1, p2 in zip(trainer.state.model.model.parameters(), loaded_hf_model.parameters()):
+ torch.testing.assert_close(p1, p2)
+
+ # Test we can load back in using peft interface
+ loaded_peft_model = peft.PeftModelForCausalLM.from_pretrained(tiny_gpt2_model, str(tmp_path / 'hf-save-pretrained'))
+ for p1, p2 in zip(trainer.state.model.model.parameters(), loaded_peft_model.parameters()):
+ torch.testing.assert_close(p1, p2)
+
+
+@pytest.mark.gpu
+@world_size(2)
+@pytest.mark.parametrize('should_save_peft_only', [True, False])
+def test_peft_fsdp_trains(tiny_gpt2_model, tiny_gpt2_tokenizer, gpt2_peft_config, tmp_path, world_size,
+ should_save_peft_only):
+ pytest.importorskip('peft')
+
+ fsdp_config = {
+ 'sharding_strategy': 'FULL_SHARD',
+ 'cpu_offload': False,
+ 'mixed_precision': 'PURE',
+ 'backward_prefetch': 'BACKWARD_PRE',
+ 'activation_checkpointing': False,
+ 'activation_cpu_offload': False,
+ 'verbose': False
+ }
+
+ stashed_model = copy.deepcopy(tiny_gpt2_model)
+
+ trainer = get_lm_trainer(
+ tiny_gpt2_model,
+ tiny_gpt2_tokenizer,
+ str(tmp_path / 'trainer1'),
+ peft_config=gpt2_peft_config,
+ device_train_microbatch_size=1,
+ mlm=False,
+ fsdp_config=fsdp_config,
+ should_save_peft_only=should_save_peft_only,
+ )
+
+ for n, p in trainer.state.model.model.named_parameters():
+ if 'lora' in n:
+ assert p.requires_grad
+ else:
+ assert not p.requires_grad
+
+ trainer.fit()
+ trainer.close()
+
+ load_trainer = get_lm_trainer(
+ stashed_model,
+ tiny_gpt2_tokenizer,
+ str(tmp_path / 'trainer2'),
+ peft_config=gpt2_peft_config,
+ device_train_microbatch_size=1,
+ mlm=False,
+ load_path=str(tmp_path / 'trainer1' / 'hf-checkpoint.pt'),
+ fsdp_config=fsdp_config,
+ should_save_peft_only=should_save_peft_only,
+ )
+
+ for n, p in load_trainer.state.model.model.named_parameters():
+ if 'lora' in n:
+ assert p.requires_grad
+ else:
+ assert not p.requires_grad
+
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+ with FSDP.summon_full_params(trainer.state.model), FSDP.summon_full_params(load_trainer.state.model):
+ for p1, p2 in zip(trainer.state.model.parameters(), load_trainer.state.model.parameters()):
+ torch.testing.assert_close(p1, p2)
+
+ if dist.get_global_rank() == 0:
+ loaded_ckpt_1 = torch.load(str(tmp_path / 'trainer1' / 'hf-checkpoint.pt'))
+
+ # Check that only the LoRA parameters were saved
+ if should_save_peft_only:
+ assert all('lora' in k for k in loaded_ckpt_1['state']['model'].keys())
+ else:
+ assert not all('lora' in k for k in loaded_ckpt_1['state']['model'].keys())
+
+
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.0'), reason='requires PyTorch 2+')
+def test_filtered_state_dict(tiny_gpt2_model, tiny_gpt2_tokenizer, gpt2_peft_config, tmp_path):
+ pytest.importorskip('peft')
+
+ hf_model = HuggingFaceModel(tiny_gpt2_model,
+ tokenizer=tiny_gpt2_tokenizer,
+ peft_config=gpt2_peft_config,
+ should_save_peft_only=True)
+ state_dict = hf_model.state_dict()
+
+ assert len(state_dict.keys()) == 4
diff --git a/tests/models/test_mmdet_model.py b/tests/models/test_mmdet_model.py
deleted file mode 100644
index fafeeb1ac5..0000000000
--- a/tests/models/test_mmdet_model.py
+++ /dev/null
@@ -1,200 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-import numpy as np
-import pytest
-import torch
-
-
-@pytest.fixture
-def mmdet_detection_batch():
- batch_size = 2
- num_labels_per_image = 20
- image_size = 224
- return {
- 'img_metas': [{
- 'filename': '../../data/coco/train2017/fake_img.jpg',
- 'ori_filename': 'fake_image.jpg',
- 'img_shape': (image_size, image_size, 3),
- 'ori_shape': (image_size, image_size, 3),
- 'pad_shape': (image_size, image_size, 3),
- 'scale_factor': np.array([1., 1., 1., 1.], dtype=np.float32)
- }] * batch_size,
- 'img':
- torch.zeros(batch_size, 3, image_size, image_size, dtype=torch.float32),
- 'gt_bboxes': [torch.zeros(num_labels_per_image, 4, dtype=torch.float32)] * batch_size,
- 'gt_labels': [torch.zeros(num_labels_per_image, dtype=torch.int64)] * batch_size
- }
-
-
-@pytest.fixture
-def mmdet_detection_eval_batch():
- # Eval settings for mmdetection datasets have an extra list around inputs.
- batch_size = 2
- num_labels_per_image = 20
- image_size = 224
- return {
- 'img_metas': [[{
- 'filename': '../../data/coco/train2017/fake_img.jpg',
- 'ori_filename': 'fake_image.jpg',
- 'img_shape': (image_size, image_size, 3),
- 'ori_shape': (image_size, image_size, 3),
- 'pad_shape': (image_size, image_size, 3),
- 'scale_factor': np.array([1., 1., 1., 1.], dtype=np.float32),
- }] * batch_size],
- 'img': [torch.zeros(batch_size, 3, image_size, image_size, dtype=torch.float32)],
- 'gt_bboxes': [[torch.zeros(num_labels_per_image, 4, dtype=torch.float32)] * batch_size],
- 'gt_labels': [[torch.zeros(num_labels_per_image, dtype=torch.int64)] * batch_size]
- }
-
-
-@pytest.fixture
-def yolox_config():
- # from https://github.com/open-mmlab/mmdetection/blob/master/configs/yolox/yolox_s_8x8_300e_coco.py
- return dict(
- type='YOLOX',
- input_size=(640, 640),
- random_size_range=(15, 25),
- random_size_interval=10,
- backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
- neck=dict(type='YOLOXPAFPN', in_channels=[128, 256, 512], out_channels=128, num_csp_blocks=1),
- bbox_head=dict(type='YOLOXHead', num_classes=80, in_channels=128, feat_channels=128),
- train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
- # In order to align the source code, the threshold of the val phase is
- # 0.01, and the threshold of the test phase is 0.001.
- test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))
-
-
-@pytest.fixture
-def faster_rcnn_config():
- # modified from https://github.com/open-mmlab/mmdetection/blob/master/configs/_base_/models/faster_rcnn_r50_fpn.py
- return dict(
- type='FasterRCNN',
- backbone=dict(type='ResNet',
- depth=50,
- num_stages=4,
- out_indices=(0, 1, 2, 3),
- frozen_stages=1,
- norm_cfg=dict(type='BN', requires_grad=True),
- norm_eval=True,
- style='pytorch'),
- neck=dict(type='FPN', in_channels=[256, 512, 1024, 2048], out_channels=256, num_outs=5),
- rpn_head=dict(type='RPNHead',
- in_channels=256,
- feat_channels=256,
- anchor_generator=dict(type='AnchorGenerator',
- scales=[8],
- ratios=[0.5, 1.0, 2.0],
- strides=[4, 8, 16, 32, 64]),
- bbox_coder=dict(type='DeltaXYWHBBoxCoder',
- target_means=[.0, .0, .0, .0],
- target_stds=[1.0, 1.0, 1.0, 1.0]),
- loss_cls=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
- loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
- roi_head=dict(type='StandardRoIHead',
- bbox_roi_extractor=dict(type='SingleRoIExtractor',
- roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
- out_channels=256,
- featmap_strides=[4, 8, 16, 32]),
- bbox_head=dict(type='Shared2FCBBoxHead',
- in_channels=256,
- fc_out_channels=1024,
- roi_feat_size=7,
- num_classes=80,
- bbox_coder=dict(type='DeltaXYWHBBoxCoder',
- target_means=[0., 0., 0., 0.],
- target_stds=[0.1, 0.1, 0.2, 0.2]),
- reg_class_agnostic=False,
- loss_cls=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
- loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
- # model training and testing settings
- train_cfg=dict(rpn=dict(assigner=dict(type='MaxIoUAssigner',
- pos_iou_thr=0.7,
- neg_iou_thr=0.3,
- min_pos_iou=0.3,
- match_low_quality=True,
- ignore_iof_thr=-1),
- sampler=dict(type='RandomSampler',
- num=256,
- pos_fraction=0.5,
- neg_pos_ub=-1,
- add_gt_as_proposals=False),
- allowed_border=-1,
- pos_weight=-1,
- debug=False),
- rpn_proposal=dict(nms_pre=2000,
- max_per_img=1000,
- nms=dict(type='nms', iou_threshold=0.7),
- min_bbox_size=0),
- rcnn=dict(assigner=dict(type='MaxIoUAssigner',
- pos_iou_thr=0.5,
- neg_iou_thr=0.5,
- min_pos_iou=0.5,
- match_low_quality=False,
- ignore_iof_thr=-1),
- sampler=dict(type='RandomSampler',
- num=512,
- pos_fraction=0.25,
- neg_pos_ub=-1,
- add_gt_as_proposals=True),
- pos_weight=-1,
- debug=False)),
- test_cfg=dict(
- rpn=dict(nms_pre=1000, max_per_img=1000, nms=dict(type='nms', iou_threshold=0.7), min_bbox_size=0),
- rcnn=dict(score_thr=0.05, nms=dict(type='nms', iou_threshold=0.5), max_per_img=100)
- # soft-nms is also supported for rcnn testing
- # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
- ))
-
-
-def test_mmdet_model_forward_yolox(mmdet_detection_batch, yolox_config):
- pytest.importorskip('mmdet')
-
- from mmcv import ConfigDict
- from mmdet.models import build_detector
-
- from composer.models import MMDetModel
-
- config = ConfigDict(yolox_config)
- # non pretrained model to avoid a slow test that downloads the weights.
- model = build_detector(config)
- model.init_weights()
- model = MMDetModel(model=model)
- out = model(mmdet_detection_batch)
- assert list(out.keys()) == ['loss_cls', 'loss_bbox', 'loss_obj']
-
-
-def test_mmdet_model_eval_forward_yolox(mmdet_detection_eval_batch, yolox_config):
- pytest.importorskip('mmdet')
-
- from mmcv import ConfigDict
- from mmdet.models import build_detector
-
- from composer.models import MMDetModel
-
- config = ConfigDict(yolox_config)
- # non pretrained model to avoid a slow test that downloads the weights.
- model = build_detector(config)
- model.init_weights()
- model = MMDetModel(model=model)
- out = model.eval_forward(mmdet_detection_eval_batch)
- assert len(out) == mmdet_detection_eval_batch['img'][0].shape[0] # batch size
- assert list(out[0].keys()) == ['labels', 'boxes', 'scores']
-
-
-def test_mmdet_model_forward_faster_rcnn(mmdet_detection_batch, faster_rcnn_config):
- pytest.importorskip('mmdet')
-
- from mmcv import ConfigDict
- from mmdet.models import build_detector
-
- from composer.models import MMDetModel
-
- config = ConfigDict(faster_rcnn_config)
-
- # non pretrained model to avoid a slow test that downloads the weights.
- model = build_detector(config)
- model.init_weights()
- model = MMDetModel(model=model)
- out = model(mmdet_detection_batch)
- assert list(out.keys()) == ['loss_rpn_cls', 'loss_rpn_bbox', 'loss_cls', 'acc', 'loss_bbox']
diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py
index 2ae9383d79..f13be17486 100644
--- a/tests/profiler/test_profiler.py
+++ b/tests/profiler/test_profiler.py
@@ -9,8 +9,10 @@
import pytest
import torch
from packaging import version
+from torch.profiler.profiler import ProfilerAction as TorchProfilerAction
-from composer.core import State
+from composer.core import Engine, Event, State, Timestamp
+from composer.loggers import Logger
from composer.profiler import Profiler, ProfilerAction, SystemProfiler, TorchProfiler, cyclic_schedule
from composer.profiler.utils import export_memory_timeline_html
@@ -170,3 +172,39 @@ def test_memory_timeline(tmp_path: pathlib.Path) -> None:
assert fig is not None, 'export_memory_timeline_html should return a figure when return_fig=True'
_, end = fig.gca().get_ylim()
assert round(end, 2) == 0.06
+
+
+def test_skip_first_after_resumption(minimal_state: State) -> None:
+ skip_first = 1
+ wait = 2
+ warmup = 3
+ active = 4
+ repeat = 1
+ schedule = cyclic_schedule(skip_first=skip_first, wait=wait, warmup=warmup, active=active, repeat=repeat)
+ mock_trace_handler = MagicMock()
+ profiler = Profiler(
+ trace_handlers=[mock_trace_handler],
+ schedule=schedule,
+ )
+ profiler.bind_to_state(minimal_state)
+ minimal_state.profiler = profiler
+
+ assert len(profiler._callbacks) >= 1
+ assert isinstance(profiler._callbacks[-1], TorchProfiler)
+ torch_profiler = profiler._callbacks[-1]
+
+ # Create torch.profiler.profile
+ logger = Logger(minimal_state)
+ engine = Engine(state=minimal_state, logger=logger)
+ engine.run_event(Event.INIT)
+ assert torch_profiler.profiler is not None
+
+ minimal_state.timestamp = Timestamp(batch_in_epoch=7)
+ assert torch_profiler.profiler.schedule(0) == TorchProfilerAction.RECORD
+
+ # Load checkpoint at batch 4
+ minimal_state.timestamp = Timestamp(batch_in_epoch=4)
+ engine.run_event(Event.BEFORE_LOAD)
+ engine.run_event(Event.AFTER_LOAD)
+ minimal_state.timestamp = Timestamp(batch_in_epoch=7)
+ assert torch_profiler.profiler.schedule(0) == TorchProfilerAction.WARMUP
diff --git a/tests/test_events.py b/tests/test_events.py
index 63bff245ba..c81feea0b0 100644
--- a/tests/test_events.py
+++ b/tests/test_events.py
@@ -5,7 +5,6 @@
import pytest
import torch
-from packaging import version
from torch.utils.data import DataLoader
from composer import Trainer
@@ -89,8 +88,6 @@ def get_trainer(self, precision='fp32', **kwargs):
id='gpu-fsdp',
marks=[
pytest.mark.gpu,
- pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
- reason='requires PyTorch 1.13 or higher'),
pytest.mark.filterwarnings('ignore::UserWarning'),
]),
])
@@ -153,6 +150,7 @@ def _assert_expected_event_calls(self, trainer: Trainer, eval_interval: Time, nu
expected_num_calls = {
Event.INIT: 1,
+ Event.BEFORE_LOAD: 1,
Event.AFTER_LOAD: 1,
Event.EPOCH_START: num_epochs,
Event.BATCH_START: total_steps,
diff --git a/tests/test_notebooks.py b/tests/test_notebooks.py
index 3d31d36406..11ad2240d3 100644
--- a/tests/test_notebooks.py
+++ b/tests/test_notebooks.py
@@ -6,6 +6,7 @@
import os
from urllib.parse import urlparse
+import importlib_metadata
import pytest
import testbook
from testbook.client import TestbookNotebookClient
@@ -21,6 +22,16 @@
for nb in glob.glob(os.path.join(nb_root, '*.ipynb')) \
]
+try:
+ importlib_metadata.files('mosaicml')
+ package_name = 'mosaicml'
+except importlib_metadata.PackageNotFoundError:
+ try:
+ importlib_metadata.files('composer')
+ package_name = 'composer'
+ except importlib_metadata.PackageNotFoundError:
+ raise RuntimeError('Could not find the package under mosaicml or composer.')
+
def patch_notebooks():
import itertools
@@ -80,11 +91,15 @@ def modify_cell_source(tb: TestbookNotebookClient, notebook_name: str, cell_sour
cell_source = cell_source.replace('batch_size = 1024', 'batch_size = 64')
cell_source = cell_source.replace('download=True', 'download=False')
if notebook_name == 'auto_microbatching':
+ cell_source = cell_source.replace('batch_size = 2048', 'batch_size = 1024')
cell_source = cell_source.replace('download=True', 'download=False')
if notebook_name == 'migrate_from_ptl':
cell_source = cell_source.replace('batch_size=256', 'batch_size=64')
cell_source = cell_source.replace('download=True', 'download=False')
+ cell_source = cell_source.replace("pip install 'mosaicml", f"pip install '{package_name}")
+ cell_source = cell_source.replace('pip install mosaicml', f'pip install {package_name}')
+
return cell_source
@@ -122,7 +137,7 @@ def test_notebook(notebook: str, device: str, s3_bucket: str):
obj = urlparse('s3://mosaicml-internal-integration-testing/read_only/CIFAR-10/')
s3 = boto3.resource('s3')
- bucket = s3.Bucket(obj.netloc)
+ bucket = s3.Bucket(obj.netloc) # pyright: ignore[reportGeneralTypeIssues]
files = bucket.objects.filter(Prefix=obj.path.lstrip('/'))
for file in files:
target = os.path.join(os.getcwd(), 'data', os.path.relpath(file.key, obj.path.lstrip('/')))
diff --git a/tests/test_precision.py b/tests/test_precision.py
index 46571529c6..2b85d3d7d2 100644
--- a/tests/test_precision.py
+++ b/tests/test_precision.py
@@ -9,8 +9,7 @@
from composer import Trainer
from composer.core import Precision, get_precision_context
-from composer.models import composer_resnet_cifar
-from tests.common import RandomImageDataset
+from tests.common import RandomImageDataset, composer_resnet
try:
import transformer_engine.pytorch as te
@@ -22,7 +21,7 @@
def get_trainer(precision: Precision, precision_config: Optional[Dict[str, Any]] = None) -> Trainer:
return Trainer(
- model=composer_resnet_cifar('resnet_9'),
+ model=composer_resnet('resnet18'),
train_dataloader=DataLoader(
dataset=RandomImageDataset(size=1024),
batch_size=512,
@@ -78,7 +77,7 @@ def predict_and_measure_memory(precision) -> int:
def test_train_precision_memory(precision: Precision):
memory_fp32 = fit_and_measure_memory(Precision.FP32)
memory_half = fit_and_measure_memory(precision)
- assert memory_half < 0.7 * memory_fp32
+ assert memory_half < 0.85 * memory_fp32
@pytest.mark.gpu
diff --git a/tests/test_simple_nlp.py b/tests/test_simple_nlp.py
index b200e7cfa5..6b53b16125 100644
--- a/tests/test_simple_nlp.py
+++ b/tests/test_simple_nlp.py
@@ -47,6 +47,7 @@ def test_simple_nlp_classification():
trainer.eval()
# Check that there is some train/eval accuracy
+ assert trainer.state.train_metrics is not None
assert trainer.state.train_metrics['MulticlassAccuracy'].compute() != 0.0
assert trainer.state.eval_metrics['eval']['MulticlassAccuracy'].compute() != 0.0
@@ -100,6 +101,7 @@ def test_simple_nlp_mlm(tiny_bert_tokenizer, tiny_bert_model):
trainer.eval()
# Check that there is some train/eval cross entropy
+ assert trainer.state.train_metrics is not None
assert trainer.state.train_metrics['LanguageCrossEntropy'].compute() != 0.0
assert trainer.state.eval_metrics['eval']['LanguageCrossEntropy'].compute() != 0.0
diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py
index 883fb04fb5..77b580b476 100644
--- a/tests/trainer/test_checkpoint.py
+++ b/tests/trainer/test_checkpoint.py
@@ -65,6 +65,7 @@ def _load_checkpoint(filename: Union[str, pathlib.Path]):
def _assert_checkpoints_equivalent(file1, file2, atol=0.0, rtol=0.0):
+ # TODO: consider merging with _assert_checkpoints_equal
checkpoint_1 = _load_checkpoint(file1)
checkpoint_2 = _load_checkpoint(file2)
@@ -85,6 +86,10 @@ def _assert_checkpoints_equivalent(file1, file2, atol=0.0, rtol=0.0):
if 'DummyStatefulCallback' in ckpt['state']['callbacks']:
del ckpt['state']['callbacks']['DummyStatefulCallback']
+ # Remove all saved checkpoints to timestamp (accumulates between runs)
+ del checkpoint_1['state']['callbacks']['CheckpointSaver']['all_saved_checkpoints_to_timestamp']
+ del checkpoint_2['state']['callbacks']['CheckpointSaver']['all_saved_checkpoints_to_timestamp']
+
deep_compare(checkpoint_1, checkpoint_2, atol=atol, rtol=rtol)
# deepspeed checkpoints do not have model or optimizer
@@ -280,6 +285,7 @@ def test_checkpoint_saver_properly_constructed(self, save_folder: str, expected_
'weights_only': False,
'save_interval': '1ep',
'num_checkpoints_to_keep': -1,
+ 'ignore_keys': None,
}
expected_folder = expected_path.rstrip('/') if expected_path != '' else '.'
mock_checkpoint_saver.assert_called_once_with(folder=expected_folder, **rest_of_checkpoint_saver_kwargs)
@@ -689,7 +695,11 @@ def test_strict_errors(self, missing_key: bool, unexpected_key: bool):
last_checkpoint = os.path.join('first', 'ep2.pt')
if missing_key or unexpected_key:
- error_context = pytest.raises(RuntimeError, match='Failed to load checkpoint due to')
+ message = r'Error\(s\) in loading state_dict'
+ if version.parse(torch.__version__) < version.parse('2.2.9'):
+ # Composer implements strict for older torch versions
+ message = 'Failed to load checkpoint due to'
+ error_context = pytest.raises(RuntimeError, match=message)
else:
error_context = contextlib.nullcontext()
@@ -741,6 +751,7 @@ def test_load_weights(self, device, load_weights_only, save_metrics):
assert metrics_equal
@pytest.mark.parametrize('load_ignore_keys,weights_equal,callbacks_equal,rng_equal', [
+ ['*', False, False, False],
['state/model/*', False, True, True],
['state/callbacks/*', True, False, True],
['rng', True, True, False],
@@ -780,6 +791,44 @@ def test_load_ignore_keys(self, load_ignore_keys, weights_equal, callbacks_equal
assert trainer_1_rng_state is not None
deep_compare(trainer_1_rng_state, trainer_2._rng_state)
+ @pytest.mark.parametrize('save_ignore_keys,weights_equal,callbacks_equal,rng_equal', [
+ ['*', False, False, False],
+ ['state/model/*', False, True, True],
+ ['state/callbacks/*', True, False, True],
+ ['rng', True, True, False],
+ ])
+ @pytest.mark.filterwarnings('ignore:.* is not in the state_dict.*:UserWarning')
+ def test_save_ignore_keys(self, save_ignore_keys, weights_equal, callbacks_equal, rng_equal):
+
+ trainer_1 = self.get_trainer(save_folder='first', save_ignore_keys=[save_ignore_keys])
+ trainer_1.fit()
+ trainer_1_rng_state = reproducibility.get_rng_state()
+ trainer_1.close()
+
+ last_checkpoint = os.path.join('first', 'ep2.pt')
+ trainer_2 = self.get_trainer(load_path=last_checkpoint)
+
+ # Check weights loaded properly
+ with contextlib.nullcontext() if weights_equal else pytest.raises(AssertionError):
+ self._assert_weights_equivalent(
+ trainer_1.state.model,
+ trainer_2.state.model,
+ )
+
+ # Check callbacks state
+ stateful_callbacks_equal = self._stateful_callbacks_equal(
+ trainer_1.state.callbacks,
+ trainer_2.state.callbacks,
+ )
+ if callbacks_equal:
+ assert stateful_callbacks_equal
+ else:
+ assert not stateful_callbacks_equal
+
+ if rng_equal:
+ assert trainer_1_rng_state is not None
+ deep_compare(trainer_1_rng_state, trainer_2._rng_state)
+
@pytest.mark.remote
@device('cpu')
@pytest.mark.parametrize('load_weights_only', [True, False])
@@ -790,8 +839,6 @@ def test_load_ignore_keys(self, load_ignore_keys, weights_equal, callbacks_equal
],
)
@pytest.mark.filterwarnings('ignore:.*The checkpoint included CUDA RNG state.*')
- @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
- reason='requires PyTorch 1.13 or higher')
def test_load_remote_checkpoint(self, device, tmp_path: pathlib.Path, load_weights_only, remote_checkpoint_uri,
remote_checkpoint_name, continue_training_dur, final_checkpoint_name, s3_bucket,
s3_read_only_prefix):
@@ -972,8 +1019,10 @@ def test_autoload_algorithm_old_checkpoint(self):
old_init, old_repr = NoOpModel.__init__, NoOpModel.__repr__
NoOpModel.__init__ = lambda self, x: None # type: ignore
NoOpModel.__repr__ = lambda self: 'NoOpModel(3)'
- with pytest.warns(UserWarning, match='required_on_load algorithm.*'), pytest.raises(
- ValueError, match='loaded state dict contains a parameter group.*'):
+ error_context = pytest.raises(KeyError, match='module.0.weight')
+ if version.parse(torch.__version__) < version.parse('2.2.9'):
+ error_context = pytest.raises(ValueError, match='loaded state dict contains a parameter group.*')
+ with pytest.warns(UserWarning, match='required_on_load algorithm.*'), error_context:
trainer_3 = self.get_trainer(load_path=os.path.join('first', 'ep1.pt'),)
trainer_3.fit(duration='1ba')
# Restore algorithm
@@ -1247,6 +1296,36 @@ def test_spin_dataloaders(
save_folder / 'second' / 'latest-rank{rank}.pt',
)
+ def test_format_load_path(self, tmp_path: pathlib.Path):
+ run_name = 'a-quick-rabbit'
+ save_folder = os.path.join(tmp_path, '{run_name}')
+ trainer = self.get_trainer(
+ run_name=run_name,
+ save_folder=os.path.join(save_folder, 'first'),
+ save_filename='ep{epoch}-rank{rank}.pt',
+ save_interval='1ep',
+ )
+
+ trainer.fit()
+ trainer.close()
+
+ resume_file = os.path.join(save_folder, 'first', 'ep1-rank0.pt')
+ trainer = self.get_trainer(
+ run_name=run_name,
+ save_folder=os.path.join(save_folder, 'second'),
+ save_filename='ep{epoch}-rank{rank}.pt',
+ save_interval='1ep',
+ load_path=resume_file, # <-- resume training from file
+ )
+ trainer.fit()
+ trainer.close()
+
+ save_folder = save_folder.replace('{run_name}', run_name)
+ _assert_checkpoints_equivalent(
+ os.path.join(save_folder, 'first', 'latest-rank{rank}.pt'),
+ os.path.join(save_folder, 'second', 'latest-rank{rank}.pt'),
+ )
+
def _assert_expected_num_checkpoints(
self,
save_folder: str,
@@ -1306,6 +1385,7 @@ def test_rotate_checkpoints(
dataset=train_dataset,
sampler=dist.get_sampler(train_dataset),
),
+ precision='fp32',
save_folder=str(save_folder),
save_filename='checkpoint_{rank}_{batch}.pt',
save_interval='1ba',
diff --git a/tests/trainer/test_ddp.py b/tests/trainer/test_ddp.py
index f34ba3862d..d9733c4285 100644
--- a/tests/trainer/test_ddp.py
+++ b/tests/trainer/test_ddp.py
@@ -7,17 +7,15 @@
import pytest
import torch
import torch.distributed
-from packaging import version
from torch.utils.data import DataLoader
import composer.core.types as types
from composer import Callback, Event
from composer.core import State
-from composer.datasets.synthetic import SyntheticBatchPairDataset
from composer.loggers import Logger
from composer.trainer.trainer import Trainer
from composer.utils import dist
-from tests.common import SimpleModel
+from tests.common import RandomClassificationDataset, SimpleModel
def get_file_path(*, is_train: bool, tmp_path: pathlib.Path) -> str:
@@ -41,8 +39,8 @@ class TrackedDataset(types.Dataset):
atomic file writes, it is slow and should not be used in any performance measurements.
"""
- def __init__(self, is_train: bool, synthetic_dataset: SyntheticBatchPairDataset, tmp_path: pathlib.Path):
- self.dataset = synthetic_dataset
+ def __init__(self, is_train: bool, dataset, tmp_path: pathlib.Path):
+ self.dataset = dataset
self.is_train = is_train
self.tmp_path = tmp_path
self.counter = 0
@@ -88,16 +86,11 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
pytest.param('gpu', False, False, id='gpu', marks=pytest.mark.gpu),
# TODO: Remove filterwarnings after FSDP removes deprecated code
pytest.param('gpu', True, False, id='deepspeed', marks=pytest.mark.gpu),
- pytest.param('gpu',
- False,
- True,
- id='fsdp',
- marks=[
- pytest.mark.gpu,
- pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
- reason='requires PyTorch 1.13 or higher'),
- pytest.mark.filterwarnings('ignore::UserWarning'),
- ]),
+ pytest.param(
+ 'gpu', False, True, id='fsdp', marks=[
+ pytest.mark.gpu,
+ pytest.mark.filterwarnings('ignore::UserWarning'),
+ ]),
])
@pytest.mark.parametrize('world_size', [
pytest.param(1),
@@ -116,19 +109,11 @@ def test_ddp(device: str, world_size: int, deepspeed: bool, fsdp: bool, tmp_path
and 2) each ddp process is indeed getting different data.
"""
- model = SimpleModel(num_classes=100)
-
train_batch_size = 10
train_subset_num_batches = 3
- synthetic_dataset = SyntheticBatchPairDataset(
- num_unique_samples_to_create=train_batch_size * train_subset_num_batches,
- total_dataset_size=10_000,
- data_shape=(model.num_features, 5, 5),
- num_classes=model.num_classes,
- )
train_dataset = TrackedDataset(
- synthetic_dataset=synthetic_dataset,
+ dataset=RandomClassificationDataset(size=train_batch_size * train_subset_num_batches,),
is_train=True,
tmp_path=tmp_path,
)
@@ -150,14 +135,8 @@ def test_ddp(device: str, world_size: int, deepspeed: bool, fsdp: bool, tmp_path
eval_batch_size = 10
eval_subset_num_batches = 3
- eval_dataset = SyntheticBatchPairDataset(
- num_unique_samples_to_create=eval_batch_size * eval_subset_num_batches,
- total_dataset_size=10_000,
- data_shape=(model.num_features, 5, 5),
- num_classes=model.num_classes,
- )
eval_dataset = TrackedDataset(
- synthetic_dataset=eval_dataset,
+ dataset=RandomClassificationDataset(size=eval_batch_size * eval_subset_num_batches,),
is_train=False,
tmp_path=tmp_path,
)
@@ -185,17 +164,19 @@ def test_ddp(device: str, world_size: int, deepspeed: bool, fsdp: bool, tmp_path
}
max_epochs = 2
- trainer = Trainer(model=model,
- train_dataloader=train_dataloader,
- eval_dataloader=eval_dataloader,
- device=device,
- max_duration=f'{max_epochs}ep',
- eval_interval='1ep',
- eval_subset_num_batches=eval_subset_num_batches,
- train_subset_num_batches=train_subset_num_batches,
- deepspeed_config={} if deepspeed else None,
- fsdp_config=fsdp_config,
- callbacks=[CheckBatch0(tmp_path)])
+ trainer = Trainer(
+ model=SimpleModel(num_classes=100),
+ train_dataloader=train_dataloader,
+ eval_dataloader=eval_dataloader,
+ device=device,
+ max_duration=f'{max_epochs}ep',
+ eval_interval='1ep',
+ eval_subset_num_batches=eval_subset_num_batches,
+ train_subset_num_batches=train_subset_num_batches,
+ deepspeed_config={} if deepspeed else None,
+ fsdp_config=fsdp_config,
+ callbacks=[CheckBatch0(tmp_path)],
+ )
trainer.fit()
diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py
index 95aaf31e97..c6f5258c49 100644
--- a/tests/trainer/test_fsdp.py
+++ b/tests/trainer/test_fsdp.py
@@ -26,8 +26,6 @@
@world_size(2)
@pytest.mark.gpu
@pytest.mark.filterwarnings('ignore:The passed in model appears to have tied weights.*:UserWarning')
-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
- reason='FSDP requires PyTorch 1.13 or higher')
def test_fsdp_device_initialization(model: ComposerClassifier, mixed_precision: str, reentrant: bool, world_size: int,
device: str):
"""test FSDP device initialization for a simple model with weight tying and a model where two modules
@@ -99,11 +97,11 @@ def test_fsdp_inits_params_once(model: ComposerClassifier, device: str, world_si
def dummy_param_init_fn(module: torch.nn.Module):
if isinstance(module, torch.nn.Linear):
torch.nn.init.ones_(module.weight)
- if module.bias is not None:
+ if module.bias is not None: # pyright: ignore[reportUnnecessaryComparison]
torch.nn.init.constant_(module.bias, 2)
# Override the param_init_fn to be deterministic so we can test the init
- model.module.param_init_fn = dummy_param_init_fn
+ model.module.param_init_fn = dummy_param_init_fn # pyright: ignore[reportGeneralTypeIssues]
# Apply the initial initialization, because it will only be called later for parameters on meta device
model.apply(model.module.param_init_fn)
# Now wrap the param_init_fn with a MagicMock so we can count calls
@@ -136,7 +134,7 @@ def dummy_param_init_fn(module: torch.nn.Module):
for module in model.modules():
if isinstance(module, torch.nn.Linear):
assert torch.all(module.weight == 1)
- if module.bias is not None:
+ if module.bias is not None: # pyright: ignore[reportUnnecessaryComparison]
assert torch.all(module.bias == 2)
@@ -144,8 +142,6 @@ def dummy_param_init_fn(module: torch.nn.Module):
@pytest.mark.parametrize('mixed_precision', _MIXED_PRECISION_TYPES)
@pytest.mark.gpu
@world_size(2)
-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
- reason='FSDP requires PyTorch 1.13 or higher')
def test_fsdp_meta_initialization_none(model: ComposerClassifier, mixed_precision: 'str', world_size: int):
"""
This test is intended to test FSDP for meta initialization when there are attributes
@@ -173,12 +169,10 @@ def test_fsdp_meta_initialization_none(model: ComposerClassifier, mixed_precisio
@pytest.mark.parametrize('backward_prefetch_limit', [1, 2])
@pytest.mark.gpu
@world_size(2)
-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
- reason='FSDP requires PyTorch 1.13 or higher')
def test_fsdp_prefetch_limit(forward_prefetch_limit: int, backward_prefetch_limit: int, world_size: int):
model = SimpleModel()
- model.fc1._fsdp_wrap = True
- model.fc2._fsdp_wrap = True
+ model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues]
+ model.fc2._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues]
dataset = RandomClassificationDataset(size=10)
dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
@@ -199,14 +193,12 @@ def test_fsdp_prefetch_limit(forward_prefetch_limit: int, backward_prefetch_limi
@pytest.mark.gpu
@world_size(2)
-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
- reason='FSDP requires PyTorch 1.13 or higher')
@pytest.mark.filterwarnings('ignore:Instantiating FSDP with custom process groups.*:UserWarning')
@pytest.mark.filterwarnings('ignore:Composer is instantiating custom process groups.*:UserWarning')
def test_fsdp_process_group(world_size: int):
model = SimpleModel()
- model.fc1._fsdp_wrap = True
- model.fc2._fsdp_wrap = True
+ model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues]
+ model.fc2._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues]
dataset = RandomClassificationDataset(size=10)
dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
@@ -258,7 +250,7 @@ def test_fsdp_act_ckpt_offload(
'activation_cpu_offload': activation_cpu_offload,
}
- model.fc1._activation_checkpointing = True
+ model.fc1._activation_checkpointing = True # pyright: ignore[reportGeneralTypeIssues]
trainer = Trainer(
model=model,
diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py
index d9b7c5b5ee..bda2a36187 100644
--- a/tests/trainer/test_fsdp_checkpoint.py
+++ b/tests/trainer/test_fsdp_checkpoint.py
@@ -11,7 +11,7 @@
import uuid
from contextlib import nullcontext as does_not_raise
from functools import partial
-from typing import Any, Callable, Optional, Sequence
+from typing import Any, Callable, Optional, Sequence, Union
from unittest.mock import patch
import numpy as np
@@ -28,10 +28,9 @@
from composer.models import ComposerClassifier
from composer.optim import DecoupledAdamW
from composer.trainer import Trainer
-from composer.utils import dist
+from composer.utils import dist, parse_uri
from composer.utils.checkpoint import is_checkpoint_legacy_sharded
from composer.utils.file_helpers import get_file
-from composer.utils.misc import using_torch_2
from composer.utils.object_store import S3ObjectStore
from composer.utils.reproducibility import get_rng_state
from tests.common import RandomClassificationDataset, deep_compare
@@ -58,9 +57,9 @@ def __init__(
for module in net:
if isinstance(module, torch.nn.Linear):
- module._fsdp_wrap = True
+ module._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues]
- net.param_init_fn = self.param_init_fn
+ net.param_init_fn = self.param_init_fn # pyright: ignore[reportGeneralTypeIssues]
super().__init__(
module=net,
num_classes=num_classes,
@@ -73,7 +72,7 @@ def param_init_fn(self, module):
if isinstance(module, torch.nn.Linear):
init_fn(module.weight)
- if module.bias is not None:
+ if module.bias is not None: # pyright: ignore[reportUnnecessaryComparison]
torch.nn.init.zeros_(module.bias)
@@ -238,7 +237,8 @@ def _compare_rng_states_between_trainers(rng_state1, rng_state2):
if 'cuda' in rank_state1_keys:
cuda_state1 = rank_state1['cuda']
cuda_state2 = rank_state2['cuda']
- torch.equal(cuda_state1, cuda_state2), f'Cuda rng state not the same between state_dicts for rank {rank}'
+ states_equal = torch.equal(cuda_state1, cuda_state2)
+ assert states_equal, f'Cuda rng state not the same between state_dicts for rank {rank}'
def _compare_metrics_between_state_dicts(state_dict1: dict[str, Any], state_dict2: dict[str, Any]):
@@ -274,8 +274,6 @@ def _compare_timestamps_between_state_dicts(state_dict1, state_dict2):
@pytest.mark.parametrize('autoresume', [True, False])
@pytest.mark.parametrize('precision', ['amp_bf16', 'amp_fp16'])
@pytest.mark.parametrize('load_fsdp_monolith_rank0_only', [True, False])
-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
- reason='requires PyTorch 1.13 or higher')
def test_fsdp_full_state_dict_load(
world_size,
tmp_path: pathlib.Path,
@@ -341,8 +339,6 @@ def test_fsdp_full_state_dict_load(
@pytest.mark.gpu
@world_size(2)
@pytest.mark.parametrize('sync_module_states', [True, False])
-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
- reason='requires PyTorch 1.13 or higher')
def test_fsdp_mixed_with_sync(
world_size,
tmp_path: pathlib.Path,
@@ -365,7 +361,7 @@ def test_fsdp_mixed_with_sync(
@world_size(2)
@pytest.mark.parametrize('precision', ['amp_bf16', 'amp_fp16'])
@pytest.mark.parametrize('sharding_strategy', ['FULL_SHARD', 'SHARD_GRAD_OP'])
-@pytest.mark.parametrize('state_dict_type', ['full', 'sharded', 'local'])
+@pytest.mark.parametrize('state_dict_type', ['full', 'sharded'])
@pytest.mark.parametrize('composer_version', [
pytest.param(
'0.13.5',
@@ -400,12 +396,11 @@ def test_fsdp_mixed_with_sync(
'0.17.0',
marks=pytest.mark.filterwarnings((r'ignore:MosaicMLLogger is not in the state_dict. Its '
r'state will not be restored.:UserWarning')),
- )
+ ),
+ '0.18.1',
])
@pytest.mark.filterwarnings(r'ignore:.*metrics are not saved with sharded state dict.*:UserWarning')
@pytest.mark.filterwarnings(r'ignore:.*The CUDA RNG state could not be loaded.*:UserWarning')
-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
- reason='requires PyTorch 1.13 or higher')
def test_fsdp_load_old_checkpoint(
world_size,
tmp_path: pathlib.Path,
@@ -416,15 +411,8 @@ def test_fsdp_load_old_checkpoint(
s3_read_only_prefix: str,
composer_version: str,
):
-
- if (version.parse(torch.__version__) >= version.parse('1.13.0') and
- composer_version not in ['0.13.5', '0.14.0', '0.14.1']):
- pytest.skip(('Composer 0.15.1 and above checkpoints were saved with '
- 'torch 2 and as a result are not compatible with torch 1.13.'))
- if (version.parse(torch.__version__) >= version.parse('2.0.0') and state_dict_type == 'local'):
- pytest.xfail(('Loading a torch 1.13 checkpoint with torch 2.0 for '
- 'state_dict_type local is not backwards compatible. See '
- 'https://github.com/pytorch/pytorch/issues/102667 for more info'))
+ if composer_version == '0.18.1' and state_dict_type == 'full' and precision == 'amp_bf16' and sharding_strategy == 'FULL_SHARD':
+ pytest.skip('TODO: This checkpoint is missing')
if composer_version in ['0.13.5', '0.14.0', '0.14.1', '0.15.1']:
rank = 0 if state_dict_type == 'full' else '{rank}'
@@ -436,7 +424,6 @@ def test_fsdp_load_old_checkpoint(
load_path_dir = (load_path_dir + 'ep0-ba2/')
load_path = load_path_dir + f'ba2_rank{rank}.pt'
-
assert is_checkpoint_legacy_sharded(
object_store=S3ObjectStore(bucket=f'{s3_bucket}'),
source_path=load_path.lstrip(f's3://{s3_bucket}/'),
@@ -445,6 +432,10 @@ def test_fsdp_load_old_checkpoint(
load_path = (f's3://{s3_bucket}/{s3_read_only_prefix}/backwards_compatibility/'
f'{composer_version}/{sharding_strategy.lower()}_{state_dict_type}_'
f'{precision}/')
+ if state_dict_type == 'full':
+ load_path += 'ba2_rank0.pt'
+ else:
+ load_path += 'ep0-ba2/'
if composer_version == '0.15.1':
num_classes = 8 # This parameter setting is very important. Don't change or the test will fail.
@@ -475,14 +466,81 @@ def test_fsdp_load_old_checkpoint(
)
state_dict2 = trainer.state.state_dict()
- if ((dist.get_global_rank() == 0 and state_dict_type == 'full') or state_dict_type in ['sharded', 'local']):
- filled_load_path = load_path.format(rank=dist.get_global_rank())
- destination = str(tmp_path / pathlib.Path(filled_load_path).name)
- get_file(filled_load_path, destination=destination)
- with open(destination, 'rb') as f:
- state_dict1 = torch.load(f)['state']
- _compare_model_params_between_state_dicts(state_dict1, state_dict2)
+ if (dist.get_global_rank() == 0 and state_dict_type == 'full') or state_dict_type == 'sharded':
+ # After composer version 0.16.0, sharded checkpoints are of type folder/__{local_rank}__{global_rank}.distcp
+ # They cannot be loaded with `get_file` as we need the whole folder to load the checkpoint.
+ # Thus, we use the DistCPObjectStoreReader to load the state_dict.
+ if state_dict_type == 'sharded' and version.parse(composer_version) >= version.parse('0.16.0'):
+ trainer2 = get_trainer(
+ num_features=32, # This parameter setting is very important. Don't change or the test will fail.
+ num_classes=8, # This parameter setting is very important. Don't change or the test will fail.
+ precision=precision,
+ max_duration='10ba', # Change this so we have slightly different model runtime settings.
+ train_metrics=train_metrics,
+ val_metrics=val_metrics,
+ fsdp_config=fsdp_config,
+ )
+ from torch.distributed import checkpoint as dist_cp
+
+ from composer.utils.checkpoint import DistCPObjectStoreReader
+
+ _, _, parsed_load_path = parse_uri(load_path)
+ gathered_tmp_path = str(dist.all_gather_object(tmp_path)[0])
+ destination = str(pathlib.Path(gathered_tmp_path) / parsed_load_path)
+ state_dict: dict[str, Any] = {
+ 'state': trainer2.state.state_dict(),
+ 'rng': get_rng_state(),
+ }
+ if version.parse(torch.__version__) < version.parse('2.2.9'):
+ state_dict['state'].pop('optimizers')
+
+ object_store = S3ObjectStore(bucket=f'{s3_bucket}')
+ storage_reader = DistCPObjectStoreReader(source_path=parsed_load_path,
+ destination_path=destination,
+ object_store=object_store,
+ device_mesh=None)
+
+ process_group = None
+ dist_cp.load_state_dict(
+ state_dict=state_dict,
+ storage_reader=storage_reader,
+ planner=None,
+ process_group=process_group,
+ )
+ if version.parse(torch.__version__) < version.parse('2.2.9'):
+ from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+ model_state_dict = state_dict['state']['model']
+ model = trainer2.state.model
+ optim = trainer2.state.optimizers[0]
+ optim_name = type(optim).__qualname__
+ optim_state_dict = load_sharded_optimizer_state_dict(model_state_dict=model_state_dict,
+ optimizer_key='optimizers',
+ storage_reader=storage_reader)
+ with fsdp_state_dict_type_context(module=model, state_dict_type=state_dict_type):
+ optim_state_dict = FSDP.optim_state_dict_to_load(
+ optim_state_dict=optim_state_dict['optimizers'][optim_name], model=model, optim=optim)
+
+ trainer2.state.optimizers[0].load_state_dict(optim_state_dict)
+
+ with fsdp_state_dict_type_context(module=model, state_dict_type=state_dict_type):
+ flattened_optim_state_dict = FSDP.optim_state_dict(model, optim) # type: ignore
+
+ state_dict['state']['optimizers'] = {
+ optim_name: flattened_optim_state_dict,
+ }
+
+ state_dict1 = state_dict['state']
+ else:
+ filled_load_path = load_path.format(rank=dist.get_global_rank())
+ destination = str(tmp_path / pathlib.Path(filled_load_path).name)
+
+ get_file(filled_load_path, destination=destination)
+ with open(destination, 'rb') as f:
+ state_dict1 = torch.load(f)['state']
+
+ _compare_model_params_between_state_dicts(state_dict1, state_dict2)
_compare_optims_between_state_dicts(state_dict1, state_dict2)
# Continue to fit to make sure we can continue training.
@@ -494,8 +552,6 @@ def test_fsdp_load_old_checkpoint(
@world_size(2)
@pytest.mark.parametrize('optimizer', ['adam', 'adamw'])
@pytest.mark.parametrize('precision', ['amp_bf16', 'amp_fp16'])
-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
- reason='requires PyTorch 1.13 or higher')
def test_fsdp_full_state_dict_load_with_ema(
world_size,
tmp_path: pathlib.Path,
@@ -551,8 +607,6 @@ def test_fsdp_full_state_dict_load_with_ema(
@world_size(2)
@pytest.mark.parametrize('is_valid_checkpoint', [True, False])
@pytest.mark.parametrize('state_dict_type', ['sharded', 'full'])
-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
- reason='requires PyTorch 1.13 or higher')
@pytest.mark.filterwarnings(r'ignore:TypedStorage is deprecated.:UserWarning')
@pytest.mark.filterwarnings(r'ignore:.*metrics are not saved with sharded state dict.*:UserWarning')
@pytest.mark.filterwarnings(r'ignore:Please use DTensor instead and we are deprecating ShardedTensor.:UserWarning')
@@ -560,11 +614,7 @@ def test_checkpoint_loading_with_validation(world_size, tmp_path, is_valid_check
# Set the error expectations.
expectation = does_not_raise()
if not is_valid_checkpoint:
- if using_torch_2() and state_dict_type == 'sharded':
- from torch.distributed.checkpoint import CheckpointException
- expectation = pytest.raises(CheckpointException)
- else:
- expectation = pytest.raises(ValueError)
+ expectation = pytest.raises(ValueError)
def mock_get_checkpoint_validation_function():
return lambda _: is_valid_checkpoint
@@ -581,10 +631,7 @@ def mock_get_checkpoint_validation_function():
# Determine the checkpoint path for loading.
checkpoint_relpath = 'ba1-rank0.pt'
if state_dict_type == 'sharded':
- if using_torch_2():
- checkpoint_relpath = 'ba1'
- else:
- checkpoint_relpath = 'ba1/ba1-rank{rank}.pt'
+ checkpoint_relpath = 'ba1'
# Load checkpoints with checkpoint validation.
with expectation:
@@ -599,25 +646,26 @@ def mock_get_checkpoint_validation_function():
@pytest.mark.gpu
@world_size(2)
-@pytest.mark.parametrize('weights_only', [False, True])
-@pytest.mark.parametrize('optimizer', ['adam', 'adamw'])
-@pytest.mark.parametrize('state_dict_type', ['sharded', 'local'])
-@pytest.mark.parametrize('precision', ['amp_bf16', 'amp_fp16'])
@pytest.mark.parametrize('use_remote', [pytest.param(True, marks=pytest.mark.remote), False])
-@pytest.mark.parametrize('autoresume', [True, False])
-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
- reason='requires PyTorch 1.13 or higher')
+@pytest.mark.parametrize('weights_only,optimizer,precision,autoresume,load_ignore_keys', [
+ [False, 'adamw', 'amp_bf16', False, None],
+ [True, 'adamw', 'amp_bf16', False, None],
+ [False, 'adam', 'amp_bf16', False, None],
+ [False, 'adamw', 'amp_fp16', False, None],
+ [False, 'adamw', 'amp_bf16', True, None],
+ [False, 'adamw', 'amp_bf16', False, ['rng']],
+])
@pytest.mark.filterwarnings(r'ignore:TypedStorage is deprecated.:UserWarning')
@pytest.mark.filterwarnings(r'ignore:.*metrics are not saved with sharded state dict.*:UserWarning')
@pytest.mark.filterwarnings(r'ignore:Please use DTensor instead and we are deprecating ShardedTensor.:UserWarning')
def test_fsdp_partitioned_state_dict_load(
world_size,
tmp_path: pathlib.Path,
- state_dict_type: str,
autoresume: bool,
precision: str,
optimizer: str,
weights_only: bool,
+ load_ignore_keys: Union[list[str], None],
use_remote,
s3_bucket,
s3_ephemeral_prefix,
@@ -625,10 +673,7 @@ def test_fsdp_partitioned_state_dict_load(
):
if weights_only and autoresume:
pytest.xfail('Weights only with autoresume is not supported')
- if state_dict_type == 'local' and using_torch_2():
- pytest.xfail(('Loading a state_dict_type="local" checkpoint with strict=True '
- 'errors out. See https://github.com/pytorch/pytorch/issues/102667 '
- 'for more info'))
+ load_ignore_keys = [] if load_ignore_keys is None else load_ignore_keys
if autoresume:
local_run_name = f'my-cool-autoresume-run-{uuid.uuid1()}'
@@ -644,7 +689,7 @@ def test_fsdp_partitioned_state_dict_load(
save_filename = 'ba{batch}-rank{rank}.pt'
- fsdp_config = FSDPConfig(state_dict_type=state_dict_type)
+ fsdp_config = FSDPConfig(state_dict_type='sharded')
trainer1 = get_trainer(
save_folder=str(save_folder),
@@ -671,19 +716,10 @@ def test_fsdp_partitioned_state_dict_load(
object_store = None
load_path = str(save_folder.format(run_name=run_name) / pathlib.Path('ba2'))
- if not using_torch_2():
- load_filename = f"{save_filename.format(batch=2, rank='{rank}')}"
- assert load_filename == 'ba2-rank{rank}.pt'
- load_path += '/' + load_filename
- assert is_checkpoint_legacy_sharded(
- object_store=object_store,
- source_path=load_path.replace(f's3://{s3_bucket}/', ''),
- )
- else:
- assert not is_checkpoint_legacy_sharded(
- object_store=object_store,
- source_path=load_path.replace(f's3://{s3_bucket}/', ''),
- )
+ assert not is_checkpoint_legacy_sharded(
+ object_store=object_store,
+ source_path=load_path.replace(f's3://{s3_bucket}/', ''),
+ )
if autoresume:
load_path = None
@@ -699,6 +735,7 @@ def test_fsdp_partitioned_state_dict_load(
optimizer=optimizer,
load_weights_only=weights_only,
fsdp_config=fsdp_config,
+ load_ignore_keys=load_ignore_keys,
)
state_dict_from_trainer2 = trainer2.state.state_dict()
rng2 = trainer2._rng_state
@@ -708,7 +745,10 @@ def test_fsdp_partitioned_state_dict_load(
state_dict_from_trainer2,
)
if not weights_only:
- _compare_rng_states_between_trainers(rng1, rng2)
+ if any('rng' in x for x in load_ignore_keys):
+ assert rng1 is not None and rng2 is None
+ else:
+ _compare_rng_states_between_trainers(rng1, rng2)
_compare_optims_between_state_dicts(
state_dict_from_trainer1_ba2,
state_dict_from_trainer2,
@@ -729,20 +769,16 @@ def test_fsdp_partitioned_state_dict_load(
@pytest.mark.gpu
@pytest.mark.remote
@world_size(2)
-@pytest.mark.parametrize('state_dict_type', ['sharded'])
@pytest.mark.parametrize('precision', ['amp_bf16', 'amp_fp16'])
-@pytest.mark.parametrize('autoresume', [False, True]) # True commented out for now
+@pytest.mark.parametrize('autoresume', [False, True])
@pytest.mark.parametrize('num_shards', [2, 4, 7])
@pytest.mark.parametrize('sharding_strategy', ['FULL_SHARD', 'SHARD_GRAD_OP'])
-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.0.1'),
- reason='requires PyTorch 2.0.1 or higher')
@pytest.mark.filterwarnings(r'ignore:TypedStorage is deprecated.:UserWarning')
@pytest.mark.filterwarnings(r'ignore:MosaicMLLogger is not in the state_dict.:UserWarning')
@pytest.mark.filterwarnings(r'ignore:.*metrics are not saved with sharded state dict.*:UserWarning')
def test_elastic_resumption(
world_size,
tmp_path: pathlib.Path,
- state_dict_type: str,
autoresume: bool,
precision: str,
sharding_strategy,
@@ -750,17 +786,13 @@ def test_elastic_resumption(
s3_read_only_prefix,
num_shards: int,
):
- if state_dict_type == 'local' and using_torch_2():
- pytest.xfail(('Loading a state_dict_type="local" checkpoint with '
- 'strict=True errors out. See https://github.com/pytorch/pytorch/issues/102667 '
- 'for more info'))
if autoresume:
run_name = 'my-autoresume-run'
else:
run_name = None
base_path = (f's3://{s3_bucket}/{s3_read_only_prefix}/elastic_test/'
- f'{sharding_strategy.lower()}_{state_dict_type}_{precision}_'
+ f'{sharding_strategy.lower()}_sharded_{precision}_'
f'{num_shards}/')
mono_load_path = os.path.join(base_path, 'mono.pt')
@@ -797,7 +829,7 @@ def test_elastic_resumption(
run_name=run_name,
max_duration='4ba',
load_weights_only=False,
- fsdp_config=FSDPConfig(state_dict_type=state_dict_type),
+ fsdp_config=FSDPConfig(state_dict_type='sharded'),
)
def get_mono_state_dict_from_sharded_one(trainer):
@@ -843,87 +875,20 @@ def compare_state_dicts():
@pytest.mark.gpu
@world_size(2)
-@pytest.mark.parametrize('state_dict_type', ['local', 'sharded'])
-@pytest.mark.parametrize('autoresume', [True])
-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
- reason='requires PyTorch 1.13 or higher')
-@pytest.mark.skipif(version.parse(torch.__version__) > version.parse('1.13.0'),
- reason='All Pytorch 2.0 checkpoints have just 1 symlink')
-def test_mismatch_timestamp_error(
- world_size,
- tmp_path: pathlib.Path,
- state_dict_type: str,
- autoresume: bool,
-):
- run_name = 'my-run-ar' if autoresume else 'my-run'
- tmp_paths = dist.all_gather_object(os.path.abspath(tmp_path))
- save_folder = str(tmp_paths[0] / pathlib.Path(run_name))
- save_filename = 'ba{batch}-rank{rank}.pt'
- trainer1 = get_trainer(
- save_folder=save_folder,
- save_filename=save_filename,
- run_name=run_name,
- autoresume=autoresume,
- max_duration='2ba',
- save_interval='1ba',
- fsdp_config=FSDPConfig(state_dict_type=state_dict_type),
- )
- trainer1.fit()
- trainer1.close()
- latest_symlink = str(pathlib.Path(save_folder) / pathlib.Path(f'latest-rank{dist.get_global_rank()}.pt'))
- latest_checkpoint_path = pathlib.Path(save_folder) / pathlib.Path('ba2') / (pathlib.Path(
- save_filename.format(batch=2, rank=dist.get_global_rank())) if not using_torch_2() else pathlib.Path(''))
- assert os.path.join(save_folder, os.readlink(latest_symlink)) == str(latest_checkpoint_path)
- oldest_checkpoint_relative_path = str(
- pathlib.Path('ba1') / (pathlib.Path(save_filename.format(batch=1, rank=dist.get_global_rank()))
- if not using_torch_2() else pathlib.Path('')))
-
- # Corrupt latest checkpoint symlink for rank1 by changing it from batch 2 checkpoint to the batch 1 one
- # and removing batch 2 checkpoint.
- if dist.get_global_rank() == 0:
- os.remove(latest_symlink)
- os.symlink(src=oldest_checkpoint_relative_path, dst=latest_symlink)
- assert os.readlink(latest_symlink) == oldest_checkpoint_relative_path
-
- dist.barrier()
- expected_error = pytest.raises(RuntimeError, match='Timestamp mismatch error:*')
-
- with expected_error:
- get_trainer(
- save_folder=save_folder,
- save_filename=save_filename,
- autoresume=autoresume,
- run_name=run_name,
- fsdp_config=FSDPConfig(state_dict_type=state_dict_type),
- )
-
-
-@pytest.mark.gpu
-@world_size(2)
-@pytest.mark.parametrize('state_dict_type', ['sharded', 'local'])
@pytest.mark.parametrize('num_ckpts_to_keep', [-1, 1, 2, 3])
-@pytest.mark.parametrize('batches_to_train', [3])
-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
- reason='requires PyTorch 1.13 or higher')
@pytest.mark.filterwarnings(r'ignore:TypedStorage is deprecated.:UserWarning')
@pytest.mark.filterwarnings(r'ignore:.*metrics are not saved with sharded state dict.*:UserWarning')
@pytest.mark.filterwarnings(r'ignore:Please use DTensor instead and we are deprecating ShardedTensor.:UserWarning')
def test_cleanup_sharded_checkpoints(
world_size,
tmp_path: pathlib.Path,
- state_dict_type: str,
num_ckpts_to_keep: int,
- batches_to_train: int,
s3_bucket,
s3_ephemeral_prefix,
request,
):
- if state_dict_type == 'local' and using_torch_2():
- pytest.xfail(('Loading a state_dict_type="local" checkpoint with strict=True '
- 'errors out. See https://github.com/pytorch/pytorch/issues/102667 '
- 'for more info'))
-
run_name = None
+ batches_to_train = 3
tmp_paths = dist.all_gather_object(os.path.abspath(tmp_path))
save_folder = os.path.join(tmp_paths[0], 'checkpoints', '{run_name}')
@@ -936,7 +901,7 @@ def test_cleanup_sharded_checkpoints(
max_duration=f'{batches_to_train}ba',
save_interval='1ba',
save_num_checkpoints_to_keep=num_ckpts_to_keep,
- fsdp_config=FSDPConfig(state_dict_type=state_dict_type),
+ fsdp_config=FSDPConfig(state_dict_type='sharded'),
)
run_name = trainer1.state.run_name
trainer1.fit()
@@ -950,9 +915,7 @@ def test_cleanup_sharded_checkpoints(
assert num_checkpoint_dirs == num_ckpts_to_keep
for ckpt_dir in dir_contents:
full_path_ckpt_dir = os.path.join(shards_dir, ckpt_dir)
- elastic_file_list = {'.metadata', *[f'__{rank}_0.distcp' for rank in range(dist.get_world_size())]}
- non_elastic_file_list = {save_filename.format(rank=rank) for rank in range(dist.get_world_size())}
- file_list = elastic_file_list if using_torch_2() else non_elastic_file_list
+ file_list = {'.metadata', *[f'__{rank}_0.distcp' for rank in range(dist.get_world_size())]}
assert set(os.listdir(full_path_ckpt_dir)) == file_list
diff --git a/tests/trainer/test_fsdp_param_groups.py b/tests/trainer/test_fsdp_param_groups.py
index a144db51a4..30e29b4de5 100644
--- a/tests/trainer/test_fsdp_param_groups.py
+++ b/tests/trainer/test_fsdp_param_groups.py
@@ -19,26 +19,19 @@
@pytest.mark.filterwarnings('ignore::UserWarning')
@device('gpu')
@world_size(2)
-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2'),
- reason='FSDP use_orig_params requires torch 2.0 or higher')
def test_fsdp_param_groups_without_orig_params(mixed_precision: str, device: str, reentrant: bool, world_size: int):
- """
-
- Ensure that FSDP with 'use_orig_params=False' raises an exception when passing in an optimizer
- with multiple param groups
-
- """
+ # Ensure that FSDP with 'use_orig_params=False' raises an exception when passing in an optimizer
+ # with multiple param groups
num_classes = 10
model = SimpleModel(num_features=1, num_classes=num_classes)
dataset = RandomClassificationDataset(shape=(num_classes,), size=2, num_classes=num_classes)
dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
- # create a different parameter per group
+ # Create a different parameter per group
param_groups = [{'params': param, 'lr': (0.1 + 0.1 * i)} for i, param in enumerate(model.parameters())]
optimizer = torch.optim.SGD(param_groups, lr=0)
- expected_error = 'Multiple optimizer groups with FSDP are only supported on torch 2.0 \
- with use_orig_params=True.'
+ expected_error = 'Multiple optimizer groups with FSDP are only supported with use_orig_params=True.'
with pytest.raises(RuntimeError, match=expected_error):
_ = Trainer(model=model,
diff --git a/tests/trainer/test_scale_schedule.py b/tests/trainer/test_scale_schedule.py
index 2ae600f70c..ec90890e07 100644
--- a/tests/trainer/test_scale_schedule.py
+++ b/tests/trainer/test_scale_schedule.py
@@ -7,11 +7,11 @@
import pytest
import torch
from torch.optim import Optimizer
-from torch.optim.lr_scheduler import ExponentialLR
+from torch.optim.lr_scheduler import ExponentialLR, LRScheduler
from torch.utils.data import DataLoader
from composer import Trainer
-from composer.core import Callback, PyTorchScheduler, State, TimeUnit
+from composer.core import Callback, State, TimeUnit
from composer.loggers.logger import Logger
from composer.optim import MultiStepScheduler
from composer.trainer._scale_schedule import scale_pytorch_scheduler
@@ -33,7 +33,7 @@ def flatten(lst: list):
class TestScaleSchedule():
@staticmethod
- def _test(targets: List[float], scheduler: PyTorchScheduler, epochs: int, optimizer: Optimizer, ssr: float):
+ def _test(targets: List[float], scheduler: LRScheduler, epochs: int, optimizer: Optimizer, ssr: float):
scale_pytorch_scheduler(scheduler, ssr)
for epoch in range(epochs):
for param_group in optimizer.param_groups:
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 6408c008b6..97ca2005ee 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -14,7 +14,6 @@
import pytest
import torch
-from packaging import version
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
@@ -122,8 +121,6 @@ def test_no_param_model(self, call_fit: bool, call_eval: bool):
if call_eval:
trainer.eval(subset_num_batches=1)
- @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.0.0'),
- reason='requires PyTorch 2.0 or higher')
@pytest.mark.parametrize('compile_config', [(None, False), ({}, True), ({'mode': 'reduce-overhead'}, True)])
def test_torch_compile(self, model: ComposerModel, compile_config: Any):
train_dataset = RandomClassificationDataset()
@@ -137,8 +134,6 @@ def test_torch_compile(self, model: ComposerModel, compile_config: Any):
compile_config=compile_config[0])
assert trainer.local_hparams['is_model_compiled'] is compile_config[1]
- @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.0.0'),
- reason='requires PyTorch 2.0 or higher')
def test_already_compiled_warning(self, caplog, model: ComposerModel):
with caplog.at_level(logging.WARNING):
train_dataset = RandomClassificationDataset()
@@ -153,20 +148,6 @@ def test_already_compiled_warning(self, caplog, model: ComposerModel):
compile_config=None)
assert '`model` is already compiled with `torch.compile`' in caplog.text
- @pytest.mark.skipif(version.parse(torch.__version__) >= version.parse('2.0.0'),
- reason='requires PyTorch 1.13 or lower')
- def test_compile_unsupported_torch_version_exception(self, caplog, model: ComposerModel):
- with pytest.raises(ValueError, match='`torch.compile` is supported for PyTorch 2.0 or higher.'):
- train_dataset = RandomClassificationDataset()
- optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
- max_duration = '2ba'
- _ = Trainer(model=model,
- max_duration=max_duration,
- train_dataloader=DataLoader(train_dataset, sampler=dist.get_sampler(train_dataset)),
- optimizers=optimizer,
- auto_log_hparams=True,
- compile_config={})
-
def test_eval_metrics(self):
model = SimpleModel()
train_dataloader = DataLoader(RandomClassificationDataset(size=1), batch_size=1)
@@ -340,9 +321,7 @@ def test_max_duration_tokens(self, tiny_bert_tokenizer, batch_size: int, sequenc
@pytest.mark.parametrize('train_subset_num_batches', [-1, 1])
def test_infinite_train_loader(self, model: ComposerModel, max_duration: Union[int, str],
train_subset_num_batches: int):
- should_raise = (isinstance(max_duration, int) or
- max_duration.endswith('ep')) and (train_subset_num_batches is None or
- train_subset_num_batches == -1)
+ should_raise = (isinstance(max_duration, int) or max_duration.endswith('ep')) and train_subset_num_batches == -1
context = pytest.raises(
ValueError,
match='max_duration cannot be specified in epochs') if should_raise else contextlib.nullcontext()
@@ -366,7 +345,7 @@ def test_reset_time(
train_dataloader: DataLoader,
model: ComposerModel,
max_duration: Time[int],
- new_duration: Time,
+ new_duration: Optional[Time],
reset_time: bool,
):
# Train once
@@ -629,8 +608,6 @@ def test_deepspeed(
trainer.fit()
@pytest.mark.gpu
- @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
- reason='requires PyTorch 1.13 or higher')
@pytest.mark.parametrize('precision', [Precision.FP32, Precision.AMP_BF16, Precision.AMP_FP16])
@pytest.mark.filterwarnings('ignore::UserWarning')
def test_fsdp(
@@ -674,8 +651,6 @@ def test_fsdp(
trainer.fit()
@pytest.mark.gpu
- @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.0.0'),
- reason='requires PyTorch 2.0 or higher')
@pytest.mark.parametrize('precision', [Precision.AMP_BF16, Precision.AMP_FP16])
@pytest.mark.parametrize('compile_config', [None, {}])
@pytest.mark.filterwarnings('ignore::UserWarning')
@@ -1114,8 +1089,6 @@ def test_training_duration_unit(
assert event_counter_callback.event_to_num_calls[Event.EPOCH_END] == 2
assert event_counter_callback.event_to_num_calls[Event.EPOCH_CHECKPOINT] == 2
- @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.0.0'),
- reason='requires PyTorch 2.0 or higher')
@pytest.mark.parametrize('is_model_compiled', [True, False])
def test_compile_uncompile_model_weights_trainer_fit(
self,
diff --git a/tests/trainer/test_trainer_eval.py b/tests/trainer/test_trainer_eval.py
index cb7f561ca3..83e526a8e0 100644
--- a/tests/trainer/test_trainer_eval.py
+++ b/tests/trainer/test_trainer_eval.py
@@ -592,7 +592,7 @@ def __len__(self) -> Optional[int]:
return None
-@pytest.mark.parametrize('eval_subset_num_batches,success', [[None, False], [-1, False], [1, True]])
+@pytest.mark.parametrize('eval_subset_num_batches,success', [[-1, False], [1, True]])
def test_infinite_eval_dataloader(eval_subset_num_batches, success):
"""Test the `eval_subset_num_batches` is required with infinite dataloader."""
# Construct the trainer
diff --git a/tests/utils/eval_client/test_local_eval_client.py b/tests/utils/eval_client/test_local_eval_client.py
index 8a598608d0..b114096ad3 100644
--- a/tests/utils/eval_client/test_local_eval_client.py
+++ b/tests/utils/eval_client/test_local_eval_client.py
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
-from composer.utils import LocalEvalClient
+from composer.utils import LocalEvalClient, dist
from tests.common.markers import world_size
@@ -29,10 +29,11 @@
)
@world_size(1, 2)
def test_local_invoke(code: str, result: str, language: str, world_size: int, tmp_path: str):
- """Test invocation function for LocalEvalClient with code that succeeds, fails compilation, times out, and is incorrect in C, C++, Python, JS.
+ """Test invocation function for LocalEvalClient.
+
+ Code can succeed, fail compilation, time out, or be incorrect in C, C++, Python, JS.
"""
- import os
- os.makedirs(os.path.dirname(tmp_path), exist_ok=True)
+ dist.barrier() # Ensure all processes are ready to run the test as invoke doesn't use dist
eval_client = LocalEvalClient()
input = '(1,)' if language == 'python' else '1'
assert eval_client.invoke([[[{
diff --git a/tests/utils/object_store/test_azure_object_store.py b/tests/utils/object_store/test_azure_object_store.py
new file mode 100644
index 0000000000..949e2149ff
--- /dev/null
+++ b/tests/utils/object_store/test_azure_object_store.py
@@ -0,0 +1,33 @@
+# Copyright 2022 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+from torch.utils.data import DataLoader
+
+from composer.trainer import Trainer
+from tests.common import RandomClassificationDataset, SimpleModel
+
+
+@pytest.mark.remote
+def test_azure_object_store_integration():
+ model = SimpleModel()
+ train_dataloader = DataLoader(dataset=RandomClassificationDataset())
+ trainer_save = Trainer(
+ model=model,
+ train_dataloader=train_dataloader,
+ save_folder='azure://mosaicml-composer-tests/checkpoints/{run_name}',
+ save_filename='test-model.pt',
+ max_duration='1ba',
+ )
+ run_name = trainer_save.state.run_name
+ trainer_save.fit()
+ trainer_save.close()
+
+ trainer_load = Trainer(
+ model=model,
+ train_dataloader=train_dataloader,
+ load_path=f'azure://mosaicml-composer-tests/checkpoints/{run_name}/test-model.pt',
+ max_duration='2ba',
+ )
+ trainer_load.fit()
+ trainer_load.close()
diff --git a/tests/utils/object_store/test_integration_gs_object_store.py b/tests/utils/object_store/test_integration_gs_object_store.py
deleted file mode 100644
index 1a08bb73ce..0000000000
--- a/tests/utils/object_store/test_integration_gs_object_store.py
+++ /dev/null
@@ -1,107 +0,0 @@
-# Copyright 2022 MosaicML Composer authors
-# SPDX-License-Identifier: Apache-2.0
-
-import time
-from pathlib import Path
-
-import pytest
-
-from composer.utils import GCSObjectStore
-
-__DUMMY_OBJ__ = '/tmp/dummy.ckpt'
-__NUM_BYTES__ = 1000
-bucket_name = 'mosaicml-composer-tests'
-
-
-@pytest.mark.remote
-@pytest.fixture
-def gs_object_store():
- pytest.skip('Run this test suite only after GCS service account is configured on CI node.')
- remote_dir = 'gs://mosaicml-composer-tests/streaming/'
- yield GCSObjectStore(remote_dir)
-
-
-@pytest.mark.remote
-def test_bucket_not_found():
- pytest.skip('Run this test suite only after GCS service account is configured on CI node.')
- with pytest.raises(FileNotFoundError):
- _ = GCSObjectStore('gs://not_a_bucket/streaming')
-
-
-@pytest.mark.remote
-def test_get_uri(gs_object_store):
- pytest.skip('Run this test suite only after GCS service account is configured on CI node.')
- object_name = 'test-object'
- expected_uri = 'gs://mosaicml-composer-tests/streaming/test-object'
- assert (gs_object_store.get_uri(object_name) == expected_uri)
-
-
-@pytest.mark.remote
-def test_get_key(gs_object_store):
- pytest.skip('Run this test suite only after GCS service account is configured on CI node.')
- object_name = 'test-object'
- expected_key = 'streaming/test-object'
- assert (gs_object_store.get_key(object_name) == expected_key)
-
-
-@pytest.mark.remote
-@pytest.mark.parametrize('result', ['success', 'not found'])
-def test_get_object_size(gs_object_store, result: str):
- pytest.skip('Run this test suite only after GCS service account is configured on CI node.')
- fn = Path(__DUMMY_OBJ__)
- with open(fn, 'wb') as fp:
- fp.write(bytes('0' * __NUM_BYTES__, 'utf-8'))
- gs_object_store.upload_object(fn)
-
- if result == 'success':
- assert (gs_object_store.get_object_size(__DUMMY_OBJ__) == __NUM_BYTES__)
- else: # not found
- with pytest.raises(FileNotFoundError):
- gs_object_store.get_object_size(__DUMMY_OBJ__ + f'time.ctime()')
-
-
-@pytest.mark.remote
-def test_upload_object(gs_object_store):
- pytest.skip('Run this test suite only after GCS service account is configured on CI node.')
- from google.cloud.storage import Blob
- destination_blob_name = '/tmp/dummy.ckpt2'
- key = gs_object_store.get_key(destination_blob_name)
- stats = Blob(bucket=gs_object_store.bucket, name=key).exists(gs_object_store.client)
- if not stats:
- gs_object_store.upload_object(__DUMMY_OBJ__, destination_blob_name)
-
-
-@pytest.mark.remote
-def test_list_objects(gs_object_store):
- pytest.skip('Run this test suite only after GCS service account is configured on CI node.')
- from google.cloud.storage import Blob
- destination_blob_name = '/tmp/dummy.ckpt2'
- key = gs_object_store.get_key(destination_blob_name)
- stats = Blob(bucket=gs_object_store.bucket, name=key).exists(gs_object_store.client)
- if not stats:
- gs_object_store.upload_object(__DUMMY_OBJ__, destination_blob_name)
- objects = gs_object_store.list_objects()
- assert (key in objects)
-
-
-@pytest.mark.remote
-@pytest.mark.parametrize('result', ['success', 'file_exists', 'obj_not_found'])
-def test_download_object(gs_object_store, tmp_path, result: str):
- pytest.skip('Run this test suite only after GCS service account is configured on CI node.')
- fn = Path(__DUMMY_OBJ__)
- with open(fn, 'wb') as fp:
- fp.write(bytes('0' * __NUM_BYTES__, 'utf-8'))
- gs_object_store.upload_object(fn)
-
- object_name = __DUMMY_OBJ__
- filename = './dummy.ckpt.download'
-
- if result == 'success':
- gs_object_store.download_object(object_name, filename, overwrite=True)
-
- elif result == 'file_exists':
- with pytest.raises(FileExistsError):
- gs_object_store.download_object(object_name, __DUMMY_OBJ__)
- else: # obj_not_found
- with pytest.raises(FileNotFoundError):
- gs_object_store.download_object(object_name + f'{time.ctime()}', filename, overwrite=True)
diff --git a/tests/utils/object_store/test_mlflow_object_store.py b/tests/utils/object_store/test_mlflow_object_store.py
index d46fc493a4..ecbedd2e50 100644
--- a/tests/utils/object_store/test_mlflow_object_store.py
+++ b/tests/utils/object_store/test_mlflow_object_store.py
@@ -8,7 +8,7 @@
import pytest
from composer.utils import MLFlowObjectStore
-from composer.utils.object_store.mlflow_object_store import PLACEHOLDER_EXPERIMENT_ID, PLACEHOLDER_RUN_ID
+from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_PLACEHOLDER, MLFLOW_RUN_ID_PLACEHOLDER
TEST_PATH_FORMAT = 'databricks/mlflow-tracking/{experiment_id}/{run_id}/artifacts/'
EXPERIMENT_ID = '123'
@@ -66,7 +66,7 @@ def test_init_with_experiment_and_no_run(monkeypatch):
mock_mlflow_client.return_value.create_run.return_value = MagicMock(
info=MagicMock(run_id=RUN_ID, run_name='test-run'))
- store = MLFlowObjectStore(TEST_PATH_FORMAT.format(experiment_id=EXPERIMENT_ID, run_id=PLACEHOLDER_RUN_ID))
+ store = MLFlowObjectStore(TEST_PATH_FORMAT.format(experiment_id=EXPERIMENT_ID, run_id=MLFLOW_RUN_ID_PLACEHOLDER))
assert store.experiment_id == EXPERIMENT_ID
assert store.run_id == RUN_ID
@@ -76,7 +76,7 @@ def test_init_with_run_and_no_experiment(monkeypatch):
monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock())
with pytest.raises(ValueError):
- MLFlowObjectStore(TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id=RUN_ID))
+ MLFlowObjectStore(TEST_PATH_FORMAT.format(experiment_id=MLFLOW_EXPERIMENT_ID_PLACEHOLDER, run_id=RUN_ID))
def test_init_with_active_run(monkeypatch):
@@ -91,7 +91,7 @@ def test_init_with_active_run(monkeypatch):
mock_active_run.return_value = MagicMock(info=MagicMock(experiment_id=EXPERIMENT_ID, run_id=RUN_ID))
store = MLFlowObjectStore(
- TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id=PLACEHOLDER_RUN_ID))
+ TEST_PATH_FORMAT.format(experiment_id=MLFLOW_EXPERIMENT_ID_PLACEHOLDER, run_id=MLFLOW_RUN_ID_PLACEHOLDER))
assert store.experiment_id == EXPERIMENT_ID
assert store.run_id == RUN_ID
@@ -109,7 +109,7 @@ def test_init_with_existing_experiment_and_no_run(monkeypatch):
info=MagicMock(run_id=RUN_ID, run_name='test-run'))
store = MLFlowObjectStore(
- TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id=PLACEHOLDER_RUN_ID))
+ TEST_PATH_FORMAT.format(experiment_id=MLFLOW_EXPERIMENT_ID_PLACEHOLDER, run_id=MLFLOW_RUN_ID_PLACEHOLDER))
assert store.experiment_id == EXPERIMENT_ID
assert store.run_id == RUN_ID
@@ -128,7 +128,7 @@ def test_init_with_no_experiment_and_no_run(monkeypatch):
info=MagicMock(run_id=RUN_ID, run_name='test-run'))
store = MLFlowObjectStore(
- TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id=PLACEHOLDER_RUN_ID))
+ TEST_PATH_FORMAT.format(experiment_id=MLFLOW_EXPERIMENT_ID_PLACEHOLDER, run_id=MLFLOW_RUN_ID_PLACEHOLDER))
assert store.experiment_id == EXPERIMENT_ID
assert store.run_id == RUN_ID
@@ -190,16 +190,19 @@ def test_get_artifact_path(mlflow_object_store):
assert mlflow_object_store.get_artifact_path(DEFAULT_PATH + ARTIFACT_PATH) == ARTIFACT_PATH
# Absolute DBFS path with placeholders
- path = TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id=PLACEHOLDER_RUN_ID) + ARTIFACT_PATH
+ path = TEST_PATH_FORMAT.format(experiment_id=MLFLOW_EXPERIMENT_ID_PLACEHOLDER,
+ run_id=MLFLOW_RUN_ID_PLACEHOLDER) + ARTIFACT_PATH
assert mlflow_object_store.get_artifact_path(path) == ARTIFACT_PATH
# Raises ValueError for different experiment ID
- path = TEST_PATH_FORMAT.format(experiment_id='different-experiment', run_id=PLACEHOLDER_RUN_ID) + ARTIFACT_PATH
+ path = TEST_PATH_FORMAT.format(experiment_id='different-experiment',
+ run_id=MLFLOW_RUN_ID_PLACEHOLDER) + ARTIFACT_PATH
with pytest.raises(ValueError):
mlflow_object_store.get_artifact_path(path)
# Raises ValueError for different run ID
- path = TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id='different-run') + ARTIFACT_PATH
+ path = TEST_PATH_FORMAT.format(experiment_id=MLFLOW_EXPERIMENT_ID_PLACEHOLDER,
+ run_id='different-run') + ARTIFACT_PATH
with pytest.raises(ValueError):
mlflow_object_store.get_artifact_path(path)
diff --git a/tests/utils/object_store/test_oci_object_store.py b/tests/utils/object_store/test_oci_object_store.py
index 49676cd2e8..728462f3b3 100644
--- a/tests/utils/object_store/test_oci_object_store.py
+++ b/tests/utils/object_store/test_oci_object_store.py
@@ -54,7 +54,7 @@ def test_upload_object(test_oci_obj_store, monkeypatch, tmp_path, mock_bucket_na
bucket_name=mock_bucket_name,
object_name=mock_object_name,
file_path=file_to_upload)
- else: # result = bucket_not_found
+ elif result == 'bucket_not_found':
bucket_not_found_msg = f'Either the bucket named f{mock_bucket_name} does not exist in the namespace*'
mock_upload_file_with_exception = Mock(side_effect=oci.exceptions.ServiceError(
status=404, code='BucketNotFound', headers={'opc-request-id': 'foo'}, message=bucket_not_found_msg))
@@ -68,7 +68,7 @@ def test_upload_object(test_oci_obj_store, monkeypatch, tmp_path, mock_bucket_na
oci_os.upload_object(mock_object_name, filename=file_to_upload)
-@pytest.mark.parametrize('result', ['success', 'file_exists', 'obj_not_found', 'bucket_not_found'])
+@pytest.mark.parametrize('result', ['success', 'file_exists', 'obj_not_found', 'bucket_not_found', 'no_code'])
def test_download_object(test_oci_obj_store, monkeypatch, tmp_path, mock_bucket_name, result: str):
oci = pytest.importorskip('oci')
oci_os = test_oci_obj_store
@@ -112,7 +112,7 @@ def test_download_object(test_oci_obj_store, monkeypatch, tmp_path, mock_bucket_
FileNotFoundError,
match=f'Object oci://{mock_bucket_name}/{mock_object_name} not found. {obj_not_found_msg}'):
oci_os.download_object(mock_object_name, filename=file_to_download_to)
- else: #result == 'bucket_not_found':
+ elif result == 'bucket_not_found':
file_to_download_to = str(tmp_path / Path('my_bucket_not_found_file.bin'))
bucket_not_found_msg = f'Either the bucket named f{mock_bucket_name} does not exist in the namespace*'
mock_get_object_fn_with_exception = Mock(side_effect=oci.exceptions.ServiceError(
@@ -125,6 +125,19 @@ def test_download_object(test_oci_obj_store, monkeypatch, tmp_path, mock_bucket_
f'Bucket specified in oci://{mock_bucket_name}/{mock_object_name} not found. {bucket_not_found_msg}'
):
oci_os.download_object(mock_object_name, filename=file_to_download_to)
+ elif result == 'no_code':
+ file_to_download_to = str(tmp_path / Path('my_bucket_not_found_file.bin'))
+ bucket_not_found_msg = f'Either the bucket named f{mock_bucket_name} does not exist in the namespace*'
+ mock_get_object_fn_with_exception = Mock(side_effect=oci.exceptions.ServiceError(
+ status=404, code=None, headers={'opc-request-id': 'foo'}, message=bucket_not_found_msg))
+ with monkeypatch.context() as m:
+ m.setattr(oci_os.client, 'get_object', mock_get_object_fn_with_exception)
+ with pytest.raises(
+ FileNotFoundError,
+ match=
+ f'Object oci://{mock_bucket_name}/{mock_object_name} not found with no error code. {bucket_not_found_msg}'
+ ):
+ oci_os.download_object(mock_object_name, filename=file_to_download_to)
@pytest.mark.parametrize('result', ['success', 'bucket_not_found'])
@@ -171,7 +184,7 @@ def __init__(self, name: str, size: int):
oci_os.list_objects(prefix=prefix)
-@pytest.mark.parametrize('result', ['success', 'obj_not_found', 'bucket_not_found'])
+@pytest.mark.parametrize('result', ['success', 'obj_not_found', 'bucket_not_found', 'no_code'])
def test_get_object_size(test_oci_obj_store, mock_bucket_name, monkeypatch, result: str):
oci = pytest.importorskip('oci')
oci_os = test_oci_obj_store
@@ -186,7 +199,6 @@ def test_get_object_size(test_oci_obj_store, mock_bucket_name, monkeypatch, resu
with monkeypatch.context() as m:
m.setattr(oci_os.client, 'get_object', mock_get_object_fn)
assert oci_os.get_object_size(mock_object_name) == mock_object_size
-
elif result == 'obj_not_found':
obj_not_found_msg = f"The object '{mock_object_name}' was not found in the bucket f'{mock_bucket_name}'"
mock_get_object_fn_with_exception = Mock(side_effect=oci.exceptions.ServiceError(
@@ -197,8 +209,7 @@ def test_get_object_size(test_oci_obj_store, mock_bucket_name, monkeypatch, resu
FileNotFoundError,
match=f'Object oci://{mock_bucket_name}/{mock_object_name} not found. {obj_not_found_msg}'):
oci_os.get_object_size(mock_object_name)
-
- else: #result == 'bucket_not_found':
+ elif result == 'bucket_not_found':
bucket_not_found_msg = f'Either the bucket named f{mock_bucket_name} does not exist in the namespace*'
mock_get_object_fn_with_exception = Mock(side_effect=oci.exceptions.ServiceError(
status=404, code='BucketNotFound', headers={'opc-request-id': 'foo'}, message=bucket_not_found_msg))
@@ -210,3 +221,15 @@ def test_get_object_size(test_oci_obj_store, mock_bucket_name, monkeypatch, resu
f'Bucket specified in oci://{mock_bucket_name}/{mock_object_name} not found. {bucket_not_found_msg}'
):
oci_os.get_object_size(mock_object_name)
+ elif result == 'bucket_not_found':
+ bucket_not_found_msg = f'Either the bucket named f{mock_bucket_name} does not exist in the namespace*'
+ mock_get_object_fn_with_exception = Mock(side_effect=oci.exceptions.ServiceError(
+ status=404, code=None, headers={'opc-request-id': 'foo'}, message=bucket_not_found_msg))
+ with monkeypatch.context() as m:
+ m.setattr(oci_os.client, 'get_object', mock_get_object_fn_with_exception)
+ with pytest.raises(
+ ValueError,
+ match=
+ f'Bucket specified in oci://{mock_bucket_name}/{mock_object_name} not found. {bucket_not_found_msg}'
+ ):
+ oci_os.get_object_size(mock_object_name)
diff --git a/tests/utils/object_store/test_s3_object_store.py b/tests/utils/object_store/test_s3_object_store.py
index eb6b8a0c72..2d7033c5be 100644
--- a/tests/utils/object_store/test_s3_object_store.py
+++ b/tests/utils/object_store/test_s3_object_store.py
@@ -41,7 +41,7 @@ def test_s3_upload_object_arguments(tmp_path: pathlib.Path, s3_bucket: str):
remote_obj_name = 'remote.txt'
object_store = S3ObjectStore(bucket=s3_bucket)
- object_store.client.upload_file = MagicMock()
+ object_store.client.upload_file = MagicMock() # pyright: ignore[reportGeneralTypeIssues]
with mock.patch.dict('os.environ'):
os.environ.pop('S3_CANNED_ACL', None)
diff --git a/tests/utils/object_store/test_uc_object_store.py b/tests/utils/object_store/test_uc_object_store.py
index 1f84143186..60845e43eb 100644
--- a/tests/utils/object_store/test_uc_object_store.py
+++ b/tests/utils/object_store/test_uc_object_store.py
@@ -78,19 +78,26 @@ def test_uc_object_store_invalid_prefix(monkeypatch):
@pytest.mark.parametrize('result', ['success', 'not_found'])
def test_get_object_size(ws_client, uc_object_store, result: str):
if result == 'success':
- db_files = pytest.importorskip('databricks.sdk.service.files')
- ws_client.files.get_status.return_value = db_files.FileInfo(file_size=100)
- assert uc_object_store.get_object_size('train.txt') == 100
+ ws_client.api_client.do.return_value = {}
+ assert uc_object_store.get_object_size('train.txt') == 1000000
elif result == 'not_found':
db_core = pytest.importorskip('databricks.sdk.core', reason='requires databricks')
- ws_client.files.get_status.side_effect = db_core.DatabricksError('The file being accessed is not found',
- error_code='NOT_FOUND')
+ ws_client.api_client.do.side_effect = db_core.DatabricksError('The file being accessed is not found',
+ error_code='NOT_FOUND')
with pytest.raises(FileNotFoundError):
uc_object_store.get_object_size('train.txt')
else:
raise NotImplementedError(f'Test for result={result} is not implemented.')
+def test_get_object_size_full_path(ws_client, uc_object_store):
+ ws_client.api_client.do.return_value = {}
+ assert uc_object_store.get_object_size('Volumes/catalog/schema/volume/train.txt') == 1000000
+ ws_client.api_client.do.assert_called_with(method='HEAD',
+ path=f'/api/2.0/fs/files/Volumes/catalog/schema/volume/train.txt',
+ headers={'Source': 'mosaicml/composer'})
+
+
def test_get_uri(uc_object_store):
assert uc_object_store.get_uri('train.txt') == 'dbfs:/Volumes/catalog/schema/volume/train.txt'
assert uc_object_store.get_uri('Volumes/catalog/schema/volume/checkpoint/model.bin'
@@ -160,6 +167,49 @@ def generate_dummy_file(_):
raise NotImplementedError(f'Test for result={result} is not implemented.')
+def test_list_objects_nested_folders(ws_client, uc_object_store):
+ expected_files = [
+ '/Volumes/catalog/volume/schema/path/to/folder/file1.txt',
+ '/Volumes/catalog/volume/schema/path/to/folder/file2.txt',
+ '/Volumes/catalog/volume/schema/path/to/folder/subdir/file1.txt',
+ '/Volumes/catalog/volume/schema/path/to/folder/subdir/file2.txt',
+ ]
+ uc_list_api_responses = [{
+ 'files': [{
+ 'path': '/Volumes/catalog/volume/schema/path/to/folder/file1.txt',
+ 'is_dir': False
+ }, {
+ 'path': '/Volumes/catalog/volume/schema/path/to/folder/file2.txt',
+ 'is_dir': False
+ }, {
+ 'path': '/Volumes/catalog/volume/schema/path/to/folder/subdir',
+ 'is_dir': True
+ }]
+ }, {
+ 'files': [{
+ 'path': '/Volumes/catalog/volume/schema/path/to/folder/subdir/file1.txt',
+ 'is_dir': False
+ }, {
+ 'path': '/Volumes/catalog/volume/schema/path/to/folder/subdir/file2.txt',
+ 'is_dir': False
+ }]
+ }]
+
+ prefix = 'Volumes/catalog/schema/volume/path/to/folder'
+
+ ws_client.api_client.do = MagicMock(side_effect=[uc_list_api_responses[0], uc_list_api_responses[1]])
+ actual_files = uc_object_store.list_objects(prefix=prefix)
+
+ assert actual_files == expected_files
+
+ ws_client.api_client.do.assert_called_with(method='GET',
+ path=uc_object_store._UC_VOLUME_LIST_API_ENDPOINT,
+ data='{"path": "/Volumes/catalog/volume/schema/path/to/folder/subdir"}',
+ headers={'Source': 'mosaicml/composer'})
+
+ assert ws_client.api_client.do.call_count == 2
+
+
@pytest.mark.parametrize('result', ['success', 'prefix_none', 'not_found', 'error'])
def test_list_objects(ws_client, uc_object_store, result):
expected_files = [
@@ -173,9 +223,6 @@ def test_list_objects(ws_client, uc_object_store, result):
}, {
'path': '/Volumes/catalog/volume/schema/path/to/folder/file2.txt',
'is_dir': False
- }, {
- 'path': '/Volumes/catalog/volume/schema/path/to/folder/samples/',
- 'is_dir': True
}]
}
diff --git a/tests/utils/test_autolog_hparams.py b/tests/utils/test_autolog_hparams.py
index 7804d7bd80..4452f7bb65 100644
--- a/tests/utils/test_autolog_hparams.py
+++ b/tests/utils/test_autolog_hparams.py
@@ -10,7 +10,7 @@
from composer.loggers import InMemoryLogger
from composer.trainer import Trainer
from composer.utils import (StringEnum, convert_flat_dict_to_nested_dict, convert_nested_dict_to_flat_dict,
- extract_hparams, using_torch_2)
+ extract_hparams)
from tests.common.datasets import RandomClassificationDataset
from tests.common.models import SimpleModel
@@ -146,7 +146,6 @@ def test_extract_hparams_trainer():
# Compile
'compile_config': None,
'is_model_compiled': False,
- 'is_torch_2_0': using_torch_2(),
# Load Checkpoint
'load_path': None,
@@ -164,6 +163,7 @@ def test_extract_hparams_trainer():
'save_overwrite': False,
'save_interval': '1ep',
'save_weights_only': False,
+ 'save_ignore_keys': None,
'save_num_checkpoints_to_keep': -1,
'save_metrics': False,
diff --git a/tests/utils/test_file_helpers.py b/tests/utils/test_file_helpers.py
index 2e757afbe4..7c4e470547 100644
--- a/tests/utils/test_file_helpers.py
+++ b/tests/utils/test_file_helpers.py
@@ -213,17 +213,6 @@ def test_safe_format_name_with_dist(monkeypatch: pytest.MonkeyPatch, world_size)
assert format_name_with_dist(format_str, 'awesome_run') == expected_str
-@world_size(2)
-def test_unsafe_format_name_with_dist(monkeypatch: pytest.MonkeyPatch, world_size):
- """Node rank is deleted, but also in the format string, so expect error."""
- vars = ['run_name', 'node_rank']
- format_str = ','.join(f'{x}={{{x}}}' for x in vars)
-
- monkeypatch.delenv('NODE_RANK')
- with pytest.raises(KeyError):
- assert format_name_with_dist(format_str, 'awesome_run') == 'run_name=awesome_run,node_rank=3'
-
-
def test_format_name_with_dist_and_time():
vars = [
'run_name',
@@ -341,7 +330,7 @@ def test_maybe_create_remote_uploader_downloader_from_uri(monkeypatch):
mock_remote_ud = MagicMock()
m.setattr(loggers, 'RemoteUploaderDownloader', mock_remote_ud)
maybe_create_remote_uploader_downloader_from_uri('gs://my-nifty-gs-bucket/path/to/checkpoints.pt', loggers=[])
- mock_remote_ud.assert_called_once_with(bucket_uri='gs://my-nifty-gs-bucket'),
+ mock_remote_ud.assert_called_once_with(bucket_uri='gs://my-nifty-gs-bucket')
with pytest.raises(NotImplementedError):
maybe_create_remote_uploader_downloader_from_uri('wandb://my-cool/checkpoint/for/my/model.pt', loggers=[])
@@ -357,7 +346,9 @@ def test_maybe_create_remote_uploader_downloader_from_uri(monkeypatch):
backend_kwargs={'path': 'Volumes/checkpoint/for/my/model.pt'})
with pytest.raises(ValueError):
- maybe_create_remote_uploader_downloader_from_uri('dbfs:/checkpoint/for/my/model.pt', loggers=[])
+ rud = maybe_create_remote_uploader_downloader_from_uri('dbfs:/checkpoint/for/my/model.pt', loggers=[])
+ assert rud is not None
+ _ = rud.remote_backend
def test_ensure_folder_is_empty(tmp_path: pathlib.Path):
diff --git a/tests/utils/test_inference.py b/tests/utils/test_inference.py
index e43aa184b7..789ad3c136 100644
--- a/tests/utils/test_inference.py
+++ b/tests/utils/test_inference.py
@@ -20,7 +20,6 @@
from composer.functional import apply_gated_linear_units
from composer.loggers import InMemoryLogger, Logger
from composer.loggers.logger_destination import LoggerDestination
-from composer.models import composer_resnet
from composer.trainer.dist_strategy import prepare_ddp_module
from composer.trainer.trainer import Trainer
from composer.utils import dist, export_with_logger, inference
@@ -28,7 +27,7 @@
from tests.common import SimpleTransformerClassifier, device
from tests.common.datasets import (RandomImageDataset, dummy_text_classification_dataloader, dummy_tiny_bert_lm_batch,
dummy_transformer_classifier_batch)
-from tests.common.models import configure_tiny_bert_hf_model
+from tests.common.models import composer_resnet, configure_tiny_bert_hf_model
class MockFileUploader(LoggerDestination):
@@ -212,11 +211,9 @@ def test_export_for_inference_onnx(model_cls, sample_input, onnx_opset_version,
if onnx_opset_version == None and version.parse(torch.__version__) < version.parse('1.13'):
pytest.skip("Don't test prior PyTorch version's default Opset version.")
- from composer.utils.misc import using_torch_2
- if using_torch_2():
- pytest.xfail(
- 'torch.onnx.errors.UnsupportedOperatorError: Exporting the operator "aten::unflatten" to ONNX opset version 14 is not supported.'
- )
+ pytest.xfail(
+ 'torch.onnx.errors.UnsupportedOperatorError: Exporting the operator "aten::unflatten" to ONNX opset version 14 is not supported.'
+ )
import onnx
import onnx.checker
@@ -328,11 +325,9 @@ def test_export_for_inference_onnx_ddp(model_cls, sample_input, onnx_opset_versi
pytest.importorskip('onnx')
pytest.importorskip('onnxruntime')
- from composer.utils.misc import using_torch_2
- if using_torch_2():
- pytest.xfail(
- 'torch.onnx.errors.UnsupportedOperatorError: Exporting the operator "aten::unflatten" to ONNX opset version 14 is not supported.'
- )
+ pytest.xfail(
+ 'torch.onnx.errors.UnsupportedOperatorError: Exporting the operator "aten::unflatten" to ONNX opset version 14 is not supported.'
+ )
if onnx_opset_version == None and version.parse(torch.__version__) < version.parse('1.13'):
pytest.skip("Don't test prior PyTorch version's default Opset version.")
diff --git a/tests/utils/test_misc.py b/tests/utils/test_misc.py
new file mode 100644
index 0000000000..333262795d
--- /dev/null
+++ b/tests/utils/test_misc.py
@@ -0,0 +1,22 @@
+# Copyright 2022 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+from composer.utils.misc import partial_format
+
+
+def test_partial_format():
+ # No args provided
+ assert partial_format('{foo} {bar} {}') == '{foo} {bar} {}'
+
+ # Keyword args
+ assert partial_format('{foo} {bar}', foo='Hello') == 'Hello {bar}'
+ assert partial_format('{foo} {bar}', foo='Hello', bar='World') == 'Hello World'
+
+ # Positional args
+ assert partial_format('{} {}', 'Hello') == 'Hello {}'
+ assert partial_format('{} {}', 'Hello', 'World') == 'Hello World'
+
+ # Positional and keyword args
+ assert partial_format('{foo} {}', 'World') == '{foo} World'
+ assert partial_format('{foo} {}', foo='Hello') == 'Hello {}'
+ assert partial_format('{foo} {}', 'World', foo='Hello') == 'Hello World'